More cleanup in mlir-hlo to prepare for the standalone build
Shuffle files around, use TableGen to register passes, and introduce a `mlir-hlo-opt.cpp` file to hold the main entry point of the -opt tool and stop relying on static registration for dialect/passes. PiperOrigin-RevId: 323674455
This commit is contained in:
parent
effd3fb4f9
commit
cd01bb4c4e
|
@ -16,16 +16,16 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace chlo {
|
||||
|
@ -37,7 +37,7 @@ class HloClientDialect : public Dialect {
|
|||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
|
||||
|
||||
} // namespace chlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -29,9 +29,9 @@ limitations under the License.
|
|||
#ifndef CHLO_OPS
|
||||
#define CHLO_OPS
|
||||
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
|
||||
def HLOClient_Dialect : Dialect {
|
||||
|
|
|
@ -18,24 +18,24 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
namespace mlir {
|
||||
class OpBuilder;
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc"
|
||||
|
||||
namespace mhlo {
|
||||
|
||||
|
@ -91,7 +91,7 @@ LogicalResult deriveShapeFromFirstOperand(
|
|||
SmallVectorImpl<Value> *reifiedReturnShapes);
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
|
||||
|
||||
} // end namespace mhlo
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -18,12 +18,12 @@ limitations under the License.
|
|||
#ifndef HLO_OPS
|
||||
#define HLO_OPS
|
||||
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
|
||||
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
|
||||
|
||||
def HLO_Dialect : Dialect {
|
||||
let name = "mhlo";
|
||||
|
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||
#ifndef HLO_OPS_BASE
|
||||
#define HLO_OPS_BASE
|
||||
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||
#ifndef HLO_UTILS
|
||||
#define HLO_UTILS
|
||||
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||
|
||||
namespace mlir {
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ limitations under the License.
|
|||
#ifndef MLIR_INFER_FUSIBILITY_OP_INTERFACE
|
||||
#define MLIR_INFER_FUSIBILITY_OP_INTERFACE
|
||||
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
// OpInterface to query if an op is fusible and to query the shape equality
|
||||
// constraint among the inputs and outputs of an op.
|
||||
|
|
|
@ -18,22 +18,22 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
class OpBuilder;
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"
|
||||
|
||||
namespace lmhlo {
|
||||
|
||||
|
@ -44,7 +44,7 @@ class LmhloDialect : public Dialect {
|
|||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||
|
||||
} // namespace lmhlo
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -33,9 +33,9 @@ limitations under the License.
|
|||
#ifndef LHLO_OPS
|
||||
#define LHLO_OPS
|
||||
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
|
||||
def LHLO_Dialect : Dialect {
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_
|
||||
#define MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
||||
void registerAllDialects();
|
||||
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_
|
|
@ -0,0 +1,65 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def LhloCopyRemovalPass : Pass<"lhlo-copy-removal", "FuncOp"> {
|
||||
let summary = "Removes redundant LHLO copy operations.";
|
||||
let constructor = "createLhloCopyRemovalPass()";
|
||||
}
|
||||
|
||||
|
||||
def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> {
|
||||
let summary = "Legalize from LHLO dialect to Linalg dialect.";
|
||||
let constructor = "createLegalizeLhloToLinalgPass()";
|
||||
}
|
||||
|
||||
|
||||
def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> {
|
||||
let summary = "Greedily fuse linalg ops obtained after LHLO lowering.";
|
||||
let constructor = "createLhloFuseLinalgPass()";
|
||||
let options = [
|
||||
Option<"use_parallel_loops_", "use-parallel-loops", "bool",
|
||||
/*default=*/"false", "Tiles GenericOp consumer to parallel loops before linalg fusion">,
|
||||
ListOption<"tile_sizes_", "tile-sizes", "unsigned",
|
||||
"Faster memory space number to promote fusion buffers to",
|
||||
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
def LhloLegalizeToAffinePass : Pass<"lhlo-legalize-to-affine", "FuncOp"> {
|
||||
let summary = "Legalize from LHLO dialect to affine dialect.";
|
||||
let constructor = "createLhloLegalizeToAffinePass()";
|
||||
}
|
||||
|
||||
|
||||
def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> {
|
||||
let summary = "Legalize from LHLO dialect to GPU dialect.";
|
||||
let constructor = "createLegalizeToGpuPass()";
|
||||
}
|
||||
|
||||
|
||||
def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> {
|
||||
let summary = "Legalize from LHLO dialect to LLVM.";
|
||||
let constructor = "createTestLhloToLLVMPass()";
|
||||
}
|
||||
|
||||
|
||||
def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> {
|
||||
let summary = "Legalize from LHLO dialect to parallel loops.";
|
||||
let constructor = "createLegalizeLhloToParallelLoopsPass()";
|
||||
}
|
||||
|
|
@ -18,8 +18,8 @@ limitations under the License.
|
|||
|
||||
#include <type_traits>
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
|
|
@ -16,12 +16,12 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> {
|
||||
let summary = "Test pass for applying chlo -> hlo legalization patterns.";
|
||||
let constructor = "createTestChloLegalizeToHloPass()";
|
||||
}
|
||||
|
||||
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
|
||||
let summary = "Legalize from HLO dialect to LHLO dialect.";
|
||||
let constructor = "createLegalizeToLhloPass()";
|
||||
}
|
||||
|
||||
def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
|
||||
let summary = "Legalize from MHLO control flow to CFG control flow.";
|
||||
let constructor = "createLegalizeControlFlowPass()";
|
||||
}
|
||||
|
||||
def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
|
||||
let summary = "Legalizes gathers to a torch index select.";
|
||||
let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
|
||||
}
|
||||
|
||||
|
||||
def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-tanh-to-approximation", "FuncOp"> {
|
||||
let summary = "Legalize tanh from standard dialect to an approximation.";
|
||||
let constructor = "createLegalizeTanhToApproximationPass()";
|
||||
}
|
||||
|
||||
|
||||
def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "FuncOp"> {
|
||||
let summary = "Legalize from HLO dialect to Linalg dialect.";
|
||||
let constructor = "createLegalizeHloToLinalgPass()";
|
||||
}
|
||||
|
||||
|
||||
def LegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "FuncOp"> {
|
||||
let summary = "Legalize from MHLO dialect to standard dialect.";
|
||||
let constructor = "createLegalizeToStdPass()";
|
||||
}
|
||||
|
||||
def LowerComplexPass : Pass<"mhlo-test-lower-complex", "FuncOp"> {
|
||||
let summary = "Lower complex operations into non-complex operations.";
|
||||
let constructor = "createLowerComplexPass()";
|
||||
}
|
||||
|
||||
|
||||
def LegalizeGeneralDotPass : Pass<"mhlo-test-lower-general-dot", "FuncOp"> {
|
||||
let summary = "Tests lowering general dot to a non-batched dot when possible.";
|
||||
let constructor = "createLegalizeGeneralDotPass()";
|
||||
}
|
||||
|
||||
|
||||
def TestMaterializeBroadcastsPass : Pass<"mhlo-test-materialize-broadcasts", "FuncOp"> {
|
||||
let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
|
||||
let constructor = "createTestMaterializeBroadcastsPass()";
|
||||
}
|
||||
|
||||
|
||||
def MhloFusionPass : Pass<"mhlo-fusion", "FuncOp"> {
|
||||
let summary = "Fuse mhlo ops to kLoop/kInput fusion patterns.";
|
||||
let constructor = "createMhloFusionPass()";
|
||||
}
|
||||
|
||||
|
||||
def OptimizeMhloPass : Pass<"mhlo-test-optimize", "FuncOp"> {
|
||||
let summary = "Run optional HLO optimizations.";
|
||||
let constructor = "createOptimizeMhloPass()";
|
||||
}
|
||||
|
||||
|
||||
def SinkConstantsToControlFlowPass : Pass<"mhlo-sink-constants-to-control-flow", "FuncOp"> {
|
||||
let summary = "Sink constants implicitly captured in control flow regions. This "
|
||||
"is necessary to export to XLA.";
|
||||
let constructor = "createSinkConstantsToControlFlowPass()";
|
||||
}
|
||||
|
||||
|
||||
def TestInferShapedTypeMethodsPass : Pass<"mhlo-test-infer-shaped-type-methods", "FuncOp"> {
|
||||
let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods.";
|
||||
let constructor = "createTestInferShapedTypeMethodsPass()";
|
||||
}
|
||||
|
||||
|
||||
def TransformUnrankedHloPass : Pass<"transform-unranked-hlo", "FuncOp"> {
|
||||
let summary = "Realize element-wise operations on ranked tensors where possible.";
|
||||
let constructor = "createTransformUnrankedHloPass()";
|
||||
}
|
||||
|
||||
|
||||
def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
|
||||
let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
|
||||
let constructor = "createTestUnfuseBatchNormPass()";
|
||||
}
|
|
@ -18,11 +18,12 @@ limitations under the License.
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class FuncOp;
|
||||
class FunctionPass;
|
||||
class ModuleOp;
|
||||
class Operation;
|
||||
template <typename T>
|
||||
|
@ -58,18 +59,26 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
|
|||
// fuse mhlo ops to kLoop/kInput fusion patterns
|
||||
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
|
||||
|
||||
/// Lowers the standard TanhOp to an approximation that does not use intrinsics.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
|
||||
|
||||
std::unique_ptr<FunctionPass> createOptimizeMhloPass();
|
||||
std::unique_ptr<FunctionPass> createLowerComplexPass();
|
||||
std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass();
|
||||
std::unique_ptr<FunctionPass> createLegalizeGatherToTorchIndexSelectPass();
|
||||
|
||||
} // namespace mhlo
|
||||
|
||||
namespace lmhlo {
|
||||
|
||||
// Lowers from LHLO dialect to Affine dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLhloLegalizeToAffinePass();
|
||||
|
||||
// Lowers from LHLO dialect to Linalg dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass();
|
||||
|
||||
// Lowers from LHLO dialect to GPU dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass();
|
||||
std::unique_ptr<FunctionPass> createLegalizeToGpuPass();
|
||||
|
||||
// Fuses linalg ops obtained after LHLO lowering. To enable fusion,
|
||||
// operations are first tiled.
|
||||
|
@ -80,7 +89,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass();
|
|||
// 'tile_sizes' provides the tile sizes to use for tiling. If the linalg
|
||||
// operation has more dimensions than tile sizes provided, 1 is used as
|
||||
// default.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg(
|
||||
std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
|
||||
bool use_parallel_loops = false, llvm::ArrayRef<unsigned> tile_sizes = {});
|
||||
|
||||
// Removes unnecessary LHLO copies which copy from the allocated buffers to the
|
||||
|
@ -94,12 +103,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
|||
|
||||
} // namespace lmhlo
|
||||
|
||||
namespace hlo {
|
||||
|
||||
/// Lowers the standard TanhOp to an approximation that does not use intrinsics.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_
|
||||
#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
||||
std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
|
||||
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
|
||||
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
|
||||
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();
|
||||
|
||||
inline void registerAllMhloPasses() {
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc"
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
|
||||
namespace lmhlo {
|
||||
|
||||
std::unique_ptr<Pass> createTestLhloToLLVMPass();
|
||||
|
||||
inline void registerAllLmhloPasses() {
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"
|
||||
}
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_
|
|
@ -18,9 +18,9 @@ limitations under the License.
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
class LLVMTypeConverter;
|
||||
|
@ -80,6 +80,11 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
|||
void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
// Populates a pattern that translates the standard TanhOp to an approximation
|
||||
// that does not use intrinsics.
|
||||
void PopulateTanhToApproximationPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
} // namespace mhlo
|
||||
|
||||
namespace lmhlo {
|
||||
|
@ -100,14 +105,6 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
|||
|
||||
} // namespace chlo
|
||||
|
||||
namespace hlo {
|
||||
|
||||
// Populates a pattern that translates the standard TanhOp to an approximation
|
||||
// that does not use intrinsics.
|
||||
void PopulateTanhToApproximationPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
|
||||
|
|
|
@ -19,12 +19,12 @@ limitations under the License.
|
|||
// Utilities relating to implementing HLO broadcasting.
|
||||
// Note: This file should not depend on any non-MLIR TensorFlow libraries.
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace hlo {
|
||||
|
|
|
@ -16,8 +16,8 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace hlo {
|
||||
|
|
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
|
|
@ -16,11 +16,11 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace hlo {
|
||||
|
|
|
@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
|
||||
#include "mlir-hlo/utils/broadcast_utils.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace chlo {
|
||||
|
@ -260,7 +260,7 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
|
|||
#undef BROADCAST_BINARY_OP_DEFS
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// chlo Dialect Constructor
|
||||
|
@ -270,7 +270,7 @@ HloClientDialect::HloClientDialect(MLIRContext* context)
|
|||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
|
|
|
@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
|
||||
// Static initialization for *HLO dialects registration.
|
||||
static mlir::DialectRegistration<mlir::mhlo::MhloDialect> mhlo_ops;
|
||||
|
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||
|
||||
// This file defines the operations used in the MHLO dialect.
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stddef.h>
|
||||
|
@ -24,44 +24,43 @@ limitations under the License.
|
|||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
#include "third_party/absl/container/flat_hash_set.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APFloat.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APInt.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/iterator_range.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/MathExtras.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Matchers.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpImplementation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/InliningUtils.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
|
||||
#include "mlir-hlo/utils/convert_op_folder.h"
|
||||
#include "mlir-hlo/utils/hlo_utils.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
|
||||
namespace mlir {
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_patterns.cc.inc"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
|
||||
#include "hlo_patterns.cc.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"
|
||||
namespace mhlo {
|
||||
|
||||
Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
|
||||
|
@ -106,7 +105,7 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
|
|||
return GetI64ElementsAttr(slice_limits, builder);
|
||||
}
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_canonicalize.inc"
|
||||
#include "mhlo_canonicalize.inc"
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -375,8 +374,8 @@ static LogicalResult Verify(CollectivePermuteOp op) {
|
|||
<< "expect source_target_pairs attribute of shape (N, 2), but got ("
|
||||
<< type.getShape() << ")";
|
||||
// Check source target pairs for duplicate sources or targets
|
||||
absl::flat_hash_set<int64_t> sources;
|
||||
absl::flat_hash_set<int64_t> targets;
|
||||
llvm::DenseSet<int64_t> sources;
|
||||
llvm::DenseSet<int64_t> targets;
|
||||
for (auto i = op.source_target_pairs().begin(),
|
||||
e = op.source_target_pairs().end();
|
||||
i != e; ++i) {
|
||||
|
@ -2123,7 +2122,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
|
|||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// mhlo Dialect Interfaces
|
||||
|
@ -2154,7 +2153,7 @@ MhloDialect::MhloDialect(MLIRContext* context)
|
|||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
|
||||
>();
|
||||
addInterfaces<HLOInlinerInterface>();
|
||||
addTypes<TokenType>();
|
||||
|
|
|
@ -15,8 +15,8 @@ limitations under the License.
|
|||
|
||||
// Canonicalization patterns for the MHLO dialect.
|
||||
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td"
|
||||
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||
include "mlir/Dialect/Shape/IR/ShapeOps.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||
|
||||
def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
|
||||
|
||||
|
|
|
@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cc.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cpp.inc"
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
|
||||
|
||||
// Static initialization for *HLO dialects registration.
|
||||
|
||||
void mlir::mhlo::registerAllDialects() {
|
||||
static bool init_once = []() {
|
||||
registerDialect<mlir::chlo::HloClientDialect>();
|
||||
registerDialect<mlir::lmhlo::LmhloDialect>();
|
||||
registerDialect<mlir::mhlo::MhloDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
|
||||
// Dependent dialects
|
||||
}
|
|
@ -15,44 +15,44 @@ limitations under the License.
|
|||
|
||||
// This file defines the operations used in the LMHLO dialect.
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APFloat.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APInt.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpImplementation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace mlir {
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"
|
||||
namespace lmhlo {
|
||||
|
||||
LmhloDialect::LmhloDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
|
@ -127,7 +127,7 @@ static LogicalResult Verify(ReshapeMemRefCastOp op) {
|
|||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
|
||||
|
||||
// TODO(cheshire): Support folding, reuse code from hlo_ops.cc.
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||
|
||||
// This is the canonicalize pattern definition file.
|
||||
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
|
||||
|
|
@ -13,19 +13,19 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/utils/broadcast_utils.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace chlo {
|
||||
|
|
|
@ -13,16 +13,17 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace chlo {
|
||||
namespace mhlo {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -32,7 +33,7 @@ struct TestChloLegalizeToHloPass
|
|||
ConversionTarget conversionTarget(getContext());
|
||||
OwningRewritePatternList conversionPatterns;
|
||||
|
||||
conversionTarget.addIllegalDialect<HloClientDialect>();
|
||||
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
|
||||
// Consider the mhlo dialect legal for tests.
|
||||
conversionTarget.addLegalDialect<mhlo::MhloDialect>();
|
||||
// The conversion uses helpers from the Standard dialect.
|
||||
|
@ -40,7 +41,7 @@ struct TestChloLegalizeToHloPass
|
|||
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
|
||||
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
|
||||
|
||||
PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
|
||||
chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
|
||||
|
||||
if (failed(applyPartialConversion(getFunction(), conversionTarget,
|
||||
conversionPatterns))) {
|
||||
|
@ -51,9 +52,10 @@ struct TestChloLegalizeToHloPass
|
|||
|
||||
} // namespace
|
||||
|
||||
} // namespace chlo
|
||||
std::unique_ptr<FunctionPass> createTestChloLegalizeToHloPass() {
|
||||
return std::make_unique<TestChloLegalizeToHloPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
||||
static mlir::PassRegistration<mlir::chlo::TestChloLegalizeToHloPass> pass(
|
||||
"mhlo-test-chlo-legalize-to-hlo",
|
||||
"Test pass for applying chlo -> hlo legalization patterns");
|
||||
|
|
|
@ -15,26 +15,25 @@ limitations under the License.
|
|||
|
||||
// This file implements logic for lowering HLO dialect to LHLO dialect.
|
||||
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineMap.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/BufferPlacement.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/BufferPlacement.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
@ -511,11 +510,8 @@ void populateHLOToLHLOConversionPattern(
|
|||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
|
||||
bool results_escape_function) {
|
||||
return absl::make_unique<HloLegalizeToLhlo>(results_escape_function);
|
||||
return std::make_unique<HloLegalizeToLhlo>(results_escape_function);
|
||||
}
|
||||
|
||||
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
|
||||
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -15,30 +15,30 @@ limitations under the License.
|
|||
|
||||
// This file implements logic for lowering MHLO dialect to Standard dialect.
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Block.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
using mlir::PassRegistration;
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
struct LegalizeControlFlow
|
||||
: public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {
|
||||
struct LegalizeControlFlowPass
|
||||
: public mlir::PassWrapper<LegalizeControlFlowPass, FunctionPass> {
|
||||
// Perform the lowering to MLIR control flow.
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
@ -206,7 +206,7 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
void LegalizeControlFlow::runOnFunction() {
|
||||
void LegalizeControlFlowPass::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
llvm::SmallVector<IfOp, 4> if_ops;
|
||||
func.walk([&](IfOp op) { if_ops.push_back(op); });
|
||||
|
@ -228,9 +228,5 @@ void LegalizeControlFlow::runOnFunction() {
|
|||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
|
||||
mlir::mhlo::createLegalizeControlFlowPass() {
|
||||
return std::make_unique<LegalizeControlFlow>();
|
||||
return std::make_unique<LegalizeControlFlowPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass(
|
||||
"mhlo-legalize-control-flow",
|
||||
"Legalize from MHLO control flow to CFG control flow");
|
||||
|
|
|
@ -13,13 +13,12 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
@ -128,8 +127,8 @@ struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
|
|||
}
|
||||
};
|
||||
|
||||
struct LegalizeGatherToTorchIndexSelect
|
||||
: public PassWrapper<LegalizeGatherToTorchIndexSelect, FunctionPass> {
|
||||
struct LegalizeGatherToTorchIndexSelectPass
|
||||
: public PassWrapper<LegalizeGatherToTorchIndexSelectPass, FunctionPass> {
|
||||
/// Perform the lowering of standard dialect operations to approximations.
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
|
@ -144,9 +143,9 @@ void PopulateGatherToTorchIndexSelectPatterns(
|
|||
patterns->insert<GatherIsTorchIndexSelect>(context);
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeGatherToTorchIndexSelect> legalize_hlo_pass(
|
||||
"mhlo-legalize-gather-to-torch-index-select",
|
||||
"Legalizes gathers to a torch index select.");
|
||||
std::unique_ptr<FunctionPass> createLegalizeGatherToTorchIndexSelectPass() {
|
||||
return std::make_unique<LegalizeGatherToTorchIndexSelectPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -16,15 +16,15 @@ limitations under the License.
|
|||
// This file implements logic for lowering the tanh standard ops to an
|
||||
// approximation.
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace hlo {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
/// Emits the fast tanh approximation that is also used by XLA.
|
||||
|
@ -126,8 +126,8 @@ class ApproximateTanhLowering : public OpRewritePattern<TanhOp> {
|
|||
}
|
||||
};
|
||||
|
||||
struct LegalizeTanhToApproximation
|
||||
: public PassWrapper<LegalizeTanhToApproximation, FunctionPass> {
|
||||
struct LegalizeTanhToApproximationPass
|
||||
: public PassWrapper<LegalizeTanhToApproximationPass, FunctionPass> {
|
||||
/// Perform the lowering of standard dialect operations to approximations.
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
|
@ -140,7 +140,7 @@ struct LegalizeTanhToApproximation
|
|||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
|
||||
createLegalizeTanhToApproximationPass() {
|
||||
return std::make_unique<LegalizeTanhToApproximation>();
|
||||
return std::make_unique<LegalizeTanhToApproximationPass>();
|
||||
}
|
||||
|
||||
void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
|
||||
|
@ -148,9 +148,5 @@ void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
|
|||
patterns->insert<ApproximateTanhLowering>(context);
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeTanhToApproximation> legalize_pass(
|
||||
"mhlo-legalize-tanh-to-approximation",
|
||||
"Legalize tanh from standard dialect to an approximation");
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -15,26 +15,25 @@ limitations under the License.
|
|||
|
||||
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
|
||||
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineExpr.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace {
|
||||
|
@ -826,8 +825,8 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
// indexing_maps = [#map0, #map0, #map0],
|
||||
// iterator_types = ["parallel", "parallel"],
|
||||
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
struct LhloLegalizeToLinalg
|
||||
: public PassWrapper<LhloLegalizeToLinalg, FunctionPass> {
|
||||
struct LhloLegalizeToLinalgPass
|
||||
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(getContext());
|
||||
|
@ -842,8 +841,8 @@ struct LhloLegalizeToLinalg
|
|||
}
|
||||
};
|
||||
|
||||
struct HloLegalizeToLinalg
|
||||
: public PassWrapper<HloLegalizeToLinalg, FunctionPass> {
|
||||
struct HloLegalizeToLinalgPass
|
||||
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(getContext());
|
||||
|
@ -861,11 +860,8 @@ struct HloLegalizeToLinalg
|
|||
|
||||
namespace lmhlo {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
|
||||
return absl::make_unique<LhloLegalizeToLinalg>();
|
||||
return std::make_unique<LhloLegalizeToLinalgPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
|
||||
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
|
||||
} // namespace lmhlo
|
||||
|
||||
namespace mhlo {
|
||||
|
@ -906,10 +902,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||
return absl::make_unique<HloLegalizeToLinalg>();
|
||||
return std::make_unique<HloLegalizeToLinalgPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass(
|
||||
"hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect");
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -15,18 +15,18 @@ limitations under the License.
|
|||
|
||||
// This file implements logic for lowering MHLO dialect to Standard dialect.
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace {
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"
|
||||
#include "generated_legalize_to_standard.inc"
|
||||
} // end anonymous namespace
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
@ -176,15 +176,15 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
|
|||
} // end anonymous namespace
|
||||
|
||||
namespace {
|
||||
struct LegalizeToStandard
|
||||
: public PassWrapper<LegalizeToStandard, FunctionPass> {
|
||||
struct LegalizeToStandardPass
|
||||
: public PassWrapper<LegalizeToStandardPass, FunctionPass> {
|
||||
/// Perform the lowering to Standard dialect.
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
|
||||
return std::make_unique<LegalizeToStandard>();
|
||||
return std::make_unique<LegalizeToStandardPass>();
|
||||
}
|
||||
|
||||
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
|
||||
|
@ -194,14 +194,11 @@ void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
|
|||
}
|
||||
|
||||
/// Perform the lowering to standard dialect.
|
||||
void LegalizeToStandard::runOnFunction() {
|
||||
void LegalizeToStandardPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeToStandard> legalize_pass(
|
||||
"mhlo-legalize-to-std", "Legalize from MHLO dialect to standard dialect");
|
||||
|
||||
} // end namespace mhlo
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -15,9 +15,9 @@ limitations under the License.
|
|||
|
||||
// This is the legalization pattern definition file for MHLO to StandardOps.
|
||||
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Nullary op patterns.
|
||||
|
|
|
@ -15,12 +15,11 @@ limitations under the License.
|
|||
|
||||
// This file implements a pass to remove redundant LHLO copy operations.
|
||||
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
|
@ -30,7 +29,8 @@ namespace {
|
|||
// arguments. All uses of each buffer are replaced with the corresponding block
|
||||
// argument and the buffer is freed. Note that this pass only works in regions
|
||||
// with a single block.
|
||||
struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
|
||||
struct LhloCopyRemovalPass
|
||||
: mlir::PassWrapper<LhloCopyRemovalPass, OperationPass<>> {
|
||||
void runOnOperation() override {
|
||||
llvm::SmallVector<mlir::Operation*, 2> eraseList;
|
||||
auto operation = getOperation();
|
||||
|
@ -95,11 +95,8 @@ struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createLhloCopyRemovalPass() {
|
||||
return absl::make_unique<LhloCopyRemoval>();
|
||||
return std::make_unique<LhloCopyRemovalPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<LhloCopyRemoval> copy_removal_pass(
|
||||
"lhlo-copy-removal", "Removes redundant LHLO copy operations");
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -16,15 +16,14 @@ limitations under the License.
|
|||
// This file implements logic for fusing linalg ops obtained after LHLO
|
||||
// lowering.
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/FoldUtils.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
|
@ -32,11 +31,13 @@ namespace {
|
|||
|
||||
using linalg::LinalgOp;
|
||||
|
||||
class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
|
||||
class LhloFuseLinalgPass
|
||||
: public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
|
||||
public:
|
||||
LhloFuseLinalg() = default;
|
||||
LhloFuseLinalg(const LhloFuseLinalg&) {}
|
||||
LhloFuseLinalg(bool use_parallel_loops, llvm::ArrayRef<unsigned> tile_sizes) {
|
||||
LhloFuseLinalgPass() = default;
|
||||
LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
|
||||
LhloFuseLinalgPass(bool use_parallel_loops,
|
||||
llvm::ArrayRef<unsigned> tile_sizes) {
|
||||
tile_sizes_ = tile_sizes;
|
||||
use_parallel_loops_.setValue(use_parallel_loops);
|
||||
}
|
||||
|
@ -138,14 +139,10 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
|
|||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg(
|
||||
std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
|
||||
bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
|
||||
return absl::make_unique<LhloFuseLinalg>(use_parallel_loops, tile_sizes);
|
||||
return std::make_unique<LhloFuseLinalgPass>(use_parallel_loops, tile_sizes);
|
||||
}
|
||||
|
||||
static PassRegistration<LhloFuseLinalg> legalize_pass(
|
||||
"lhlo-fuse-linalg",
|
||||
"Greedily fuse linalg ops obtained after LHLO lowering.");
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -15,17 +15,16 @@ limitations under the License.
|
|||
|
||||
// This file implements logic for lowering LHLO dialect to Affine dialect.
|
||||
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
|
@ -138,8 +137,8 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
|
|||
// clang-format on
|
||||
}
|
||||
|
||||
struct LhloLegalizeToAffine
|
||||
: public PassWrapper<LhloLegalizeToAffine, FunctionPass> {
|
||||
struct LhloLegalizeToAffinePass
|
||||
: public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
|
@ -150,12 +149,9 @@ struct LhloLegalizeToAffine
|
|||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() {
|
||||
return absl::make_unique<LhloLegalizeToAffine>();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLhloLegalizeToAffinePass() {
|
||||
return std::make_unique<LhloLegalizeToAffinePass>();
|
||||
}
|
||||
|
||||
static PassRegistration<LhloLegalizeToAffine> legalize_pass(
|
||||
"lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect");
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -17,25 +17,24 @@ limitations under the License.
|
|||
|
||||
#include <cstdint>
|
||||
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
|
@ -168,7 +167,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
|||
};
|
||||
};
|
||||
|
||||
struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
|
||||
struct LhloLegalizeToGpuPass
|
||||
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(getContext());
|
||||
|
@ -185,12 +185,9 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
|
|||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() {
|
||||
return absl::make_unique<LhloLegalizeToGpu>();
|
||||
std::unique_ptr<FunctionPass> createLegalizeToGpuPass() {
|
||||
return std::make_unique<LhloLegalizeToGpuPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<LhloLegalizeToGpu> legalize_pass(
|
||||
"lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect");
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
|
|
|
@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
|
@ -57,8 +57,9 @@ class TestLhloToLLVMPass
|
|||
|
||||
} // namespace
|
||||
|
||||
static PassRegistration<TestLhloToLLVMPass> legalize_lhlo_pass(
|
||||
"test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM.");
|
||||
std::unique_ptr<Pass> createTestLhloToLLVMPass() {
|
||||
return std::make_unique<TestLhloToLLVMPass>();
|
||||
}
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -13,17 +13,16 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
|
@ -690,8 +689,8 @@ class SelectAndScatterOpConverter
|
|||
}
|
||||
};
|
||||
|
||||
struct LhloLegalizeToParallelLoops
|
||||
: public PassWrapper<LhloLegalizeToParallelLoops, FunctionPass> {
|
||||
struct LhloLegalizeToParallelLoopsPass
|
||||
: public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
auto func = getFunction();
|
||||
|
||||
|
@ -715,16 +714,11 @@ struct LhloLegalizeToParallelLoops
|
|||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
|
||||
return absl::make_unique<LhloLegalizeToParallelLoops>();
|
||||
return std::make_unique<LhloLegalizeToParallelLoopsPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<LhloLegalizeToParallelLoops> legalize_lhlo_pass(
|
||||
"lhlo-legalize-to-parallel-loops",
|
||||
"Legalize from LHLO dialect to parallel loops.");
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -22,18 +22,18 @@ limitations under the License.
|
|||
#include <iterator>
|
||||
#include <numeric>
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/utils/hlo_utils.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
|
||||
using mlir::FunctionPass;
|
||||
using mlir::OwningRewritePatternList;
|
||||
|
@ -41,9 +41,9 @@ using mlir::PassRegistration;
|
|||
using mlir::PassWrapper;
|
||||
|
||||
namespace {
|
||||
class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
|
||||
class LowerComplexPass : public PassWrapper<LowerComplexPass, FunctionPass> {
|
||||
public:
|
||||
explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {}
|
||||
explicit LowerComplexPass() : PassWrapper<LowerComplexPass, FunctionPass>() {}
|
||||
|
||||
/// Performs the lowering to MHLO dialect.
|
||||
void runOnFunction() override;
|
||||
|
@ -54,7 +54,7 @@ namespace mlir {
|
|||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc"
|
||||
#include "generated_lower_complex.inc"
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -66,7 +66,7 @@ void PopulateComplexLoweringPatterns(MLIRContext* context,
|
|||
} // end namespace mlir
|
||||
|
||||
// Lowers the complex operations that can be represented using other operations.
|
||||
void LowerComplex::runOnFunction() {
|
||||
void LowerComplexPass::runOnFunction() {
|
||||
// Add lowering patterns to the list.
|
||||
OwningRewritePatternList patterns;
|
||||
mlir::mhlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);
|
||||
|
@ -74,6 +74,6 @@ void LowerComplex::runOnFunction() {
|
|||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
|
||||
static PassRegistration<LowerComplex> pass(
|
||||
"mhlo-test-lower-complex",
|
||||
"Lower complex operations into non-complex operations");
|
||||
std::unique_ptr<FunctionPass> mlir::mhlo::createLowerComplexPass() {
|
||||
return std::make_unique<LowerComplexPass>();
|
||||
}
|
||||
|
|
|
@ -16,9 +16,9 @@ limitations under the License.
|
|||
// This is the legalization pattern that converts complex operations into
|
||||
// equivalent real value operations.
|
||||
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Binary op patterns.
|
||||
|
|
|
@ -15,20 +15,20 @@ limitations under the License.
|
|||
|
||||
// This file implements logic for lowering MHLO general dot to a regular dot.
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using mlir::DenseIntElementsAttr;
|
||||
using mlir::ElementsAttr;
|
||||
|
@ -170,8 +170,8 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
|
|||
}
|
||||
};
|
||||
|
||||
struct LegalizeGeneralDot
|
||||
: public PassWrapper<LegalizeGeneralDot, FunctionPass> {
|
||||
struct LegalizeGeneralDotPass
|
||||
: public PassWrapper<LegalizeGeneralDotPass, FunctionPass> {
|
||||
/// Lower all general dots that can be represented as a non-batched matmul.
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
|
@ -187,6 +187,6 @@ void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
|
|||
patterns->insert<GeneralDotConvert>(ctx);
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeGeneralDot> legalize_pass(
|
||||
"mhlo-test-lower-general-dot",
|
||||
"Tests lowering general dot to a non-batched dot when possible");
|
||||
std::unique_ptr<::mlir::Pass> mlir::mhlo::createLegalizeGeneralDotPass() {
|
||||
return std::make_unique<LegalizeGeneralDotPass>();
|
||||
}
|
||||
|
|
|
@ -15,12 +15,12 @@ limitations under the License.
|
|||
|
||||
#include <numeric>
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
|
|
@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass
|
|||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<::mlir::Pass> createTestMaterializeBroadcastsPass() {
|
||||
return std::make_unique<TestMaterializeBroadcastsPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
||||
static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
|
||||
"mhlo-test-materialize-broadcasts",
|
||||
"Test pass for materializing 'broadcast_dimensions' attributes");
|
||||
|
|
|
@ -18,14 +18,14 @@ limitations under the License.
|
|||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/EquivalenceClasses.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/utils/cycle_detector.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/EquivalenceClasses.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"
|
||||
|
||||
// This pass has similar functionality of the fusion pass in XLA stack.
|
||||
// However, unlike XLA, it targets the fully dynamic shape scenario.
|
||||
|
@ -479,7 +479,7 @@ class FusionPlanner {
|
|||
EquivalenceClasses<int32_t> leader_for_node_;
|
||||
};
|
||||
|
||||
struct MhloFusion : public mlir::PassWrapper<MhloFusion, FunctionPass> {
|
||||
struct MhloFusionPass : public mlir::PassWrapper<MhloFusionPass, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
FuncOp func = getFunction();
|
||||
if (!IsTargetFunc(func)) {
|
||||
|
@ -568,12 +568,9 @@ struct MhloFusion : public mlir::PassWrapper<MhloFusion, FunctionPass> {
|
|||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createMhloFusion() {
|
||||
return std::make_unique<MhloFusion>();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass() {
|
||||
return std::make_unique<MhloFusionPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<MhloFusion> mhlo_fusion_pass(
|
||||
"mhlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -21,18 +21,18 @@ limitations under the License.
|
|||
#include <iterator>
|
||||
#include <numeric>
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/utils/hlo_utils.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
|
||||
using mlir::OwningRewritePatternList;
|
||||
|
||||
|
|
|
@ -13,23 +13,24 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using mlir::FunctionPass;
|
||||
using mlir::PassRegistration;
|
||||
using mlir::PassWrapper;
|
||||
|
||||
namespace {
|
||||
class OptimizeMhlo : public PassWrapper<OptimizeMhlo, FunctionPass> {
|
||||
class OptimizeMhloPass : public PassWrapper<OptimizeMhloPass, FunctionPass> {
|
||||
public:
|
||||
explicit OptimizeMhlo() : PassWrapper<OptimizeMhlo, FunctionPass>() {}
|
||||
explicit OptimizeMhloPass() : PassWrapper<OptimizeMhloPass, FunctionPass>() {}
|
||||
|
||||
/// Performs the lowering to MHLO dialect.
|
||||
void runOnFunction() override;
|
||||
|
@ -37,7 +38,7 @@ class OptimizeMhlo : public PassWrapper<OptimizeMhlo, FunctionPass> {
|
|||
} // end anonymous namespace
|
||||
|
||||
// Lowers the complex operations that can be represented using other operations.
|
||||
void OptimizeMhlo::runOnFunction() {
|
||||
void OptimizeMhloPass::runOnFunction() {
|
||||
// Add lowering patterns to the list.
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
|
||||
|
@ -45,5 +46,6 @@ void OptimizeMhlo::runOnFunction() {
|
|||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
|
||||
static PassRegistration<OptimizeMhlo> pass("mhlo-test-optimize",
|
||||
"Run optional HLO optimizations.");
|
||||
std::unique_ptr<mlir::FunctionPass> mlir::mhlo::createOptimizeMhloPass() {
|
||||
return std::make_unique<OptimizeMhloPass>();
|
||||
}
|
||||
|
|
|
@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassManager.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/RegionUtils.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
@ -29,8 +29,8 @@ namespace {
|
|||
|
||||
// A pass that sinks constants implicitly captured in control flow regions. This
|
||||
// is necessary to export to XLA.
|
||||
class SinkConstantsToControlFlow
|
||||
: public mlir::PassWrapper<SinkConstantsToControlFlow, FunctionPass> {
|
||||
class SinkConstantsToControlFlowPass
|
||||
: public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
getFunction().walk([](Operation* op) {
|
||||
if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
|
||||
|
@ -70,15 +70,10 @@ class SinkConstantsToControlFlow
|
|||
}
|
||||
};
|
||||
|
||||
static mlir::PassRegistration<SinkConstantsToControlFlow> pass(
|
||||
"mhlo-sink-constants-to-control-flow",
|
||||
"Sink constants implicitly captured in control flow regions. This is "
|
||||
"necessary to export to XLA.");
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
|
||||
return std::make_unique<SinkConstantsToControlFlow>();
|
||||
return std::make_unique<SinkConstantsToControlFlowPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
|
|
|
@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Identifier.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace hlo {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
struct InferReturnTypeComponentsPattern : public RewritePattern {
|
||||
|
@ -92,9 +92,10 @@ struct TestInferShapedTypeMethodsPass
|
|||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
||||
static mlir::PassRegistration<mlir::hlo::TestInferShapedTypeMethodsPass> pass(
|
||||
"mhlo-test-infer-shaped-type-methods",
|
||||
"Uses test ops to invoke InferShapedTypeOpInterface methods");
|
||||
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass() {
|
||||
return std::make_unique<TestInferShapedTypeMethodsPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -14,18 +14,17 @@ limitations under the License.
|
|||
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
@ -204,9 +203,9 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
|||
// clang-format on
|
||||
}
|
||||
|
||||
static PassRegistration<TransformUnrankedHloPass> transform_unranked_hlo_pass(
|
||||
"transform-unranked-hlo",
|
||||
"Realize element-wise operations on ranked tensors where possible");
|
||||
std::unique_ptr<::mlir::Pass> createTransformUnrankedHloPass() {
|
||||
return std::make_unique<TransformUnrankedHloPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
|
|
@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass
|
|||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<::mlir::Pass> createTestUnfuseBatchNormPass() {
|
||||
return std::make_unique<TestUnfuseBatchNormPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
||||
static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass(
|
||||
"mhlo-test-unfuse-batch-norm",
|
||||
"Test pass for materializing 'broadcast_dimensions' attributes");
|
||||
|
|
|
@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
|
||||
#include "mlir-hlo/utils/broadcast_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/Sequence.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace hlo {
|
||||
|
|
|
@ -15,11 +15,11 @@ limitations under the License.
|
|||
|
||||
// This file defines helpers useful when creating or manipulating lhlo/hlo.
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h"
|
||||
#include "mlir-hlo/utils/convert_op_folder.h"
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace hlo {
|
||||
|
|
|
@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"
|
||||
#include "mlir-hlo/utils/cycle_detector.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"
|
||||
#include "mlir-hlo/utils/cycle_detector.h"
|
||||
|
||||
#include "third_party/tensorflow/compiler/xla/test.h"
|
||||
|
||||
|
|
|
@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||||
#include "mlir-hlo/utils/hlo_utils.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace hlo {
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
|
||||
llvm::cl::desc("<input file>"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<std::string> outputFilename(
|
||||
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> splitInputFile(
|
||||
"split-input-file",
|
||||
llvm::cl::desc("Split the input file into pieces and process each "
|
||||
"chunk independently"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> verifyDiagnostics(
|
||||
"verify-diagnostics",
|
||||
llvm::cl::desc("Check that emitted diagnostics match "
|
||||
"expected-* lines on the corresponding line"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> verifyPasses(
|
||||
"verify-each",
|
||||
llvm::cl::desc("Run the verifier after each transformation pass"),
|
||||
llvm::cl::init(true));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> allowUnregisteredDialects(
|
||||
"allow-unregistered-dialect",
|
||||
llvm::cl::desc("Allow operation with no registered dialects"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> showDialects(
|
||||
"show-dialects", llvm::cl::desc("Print the list of registered dialects"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
mlir::registerAllDialects();
|
||||
mlir::registerAllPasses();
|
||||
|
||||
mlir::mhlo::registerAllDialects();
|
||||
mlir::mhlo::registerAllMhloPasses();
|
||||
mlir::lmhlo::registerAllLmhloPasses();
|
||||
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
||||
// Register any pass manager command line options.
|
||||
mlir::registerPassManagerCLOptions();
|
||||
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
|
||||
|
||||
// Parse pass names in main to ensure static initialization completed.
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv,
|
||||
"MLIR modular optimizer driver\n");
|
||||
|
||||
if (showDialects) {
|
||||
mlir::MLIRContext context;
|
||||
llvm::outs() << "Registered Dialects:\n";
|
||||
for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
|
||||
llvm::outs() << dialect->getNamespace() << "\n";
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Set up the input file.
|
||||
std::string errorMessage;
|
||||
auto file = mlir::openInputFile(inputFilename, &errorMessage);
|
||||
if (!file) {
|
||||
llvm::errs() << errorMessage << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto output = mlir::openOutputFile(outputFilename, &errorMessage);
|
||||
if (!output) {
|
||||
llvm::errs() << errorMessage << "\n";
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
|
||||
splitInputFile, verifyDiagnostics, verifyPasses,
|
||||
allowUnregisteredDialects))) {
|
||||
return 1;
|
||||
}
|
||||
// Keep the output file if the invocation of MlirOptMain was successful.
|
||||
output->keep();
|
||||
return 0;
|
||||
}
|
Loading…
Reference in New Issue