Cleanup build rule names in compiler/mlir/hlo to remove the redundant/obsolete xla_ prefix

PiperOrigin-RevId: 320320140

commit 506ddd9c4a
parent f4303855c4

@@ -38,7 +38,7 @@ def HLOClient_Dialect : Dialect {
   let name = "chlo";
   let cppNamespace = "chlo";
   let summary = [{
-    XLA Client HLO Ops
+    Client HLO Ops
   }];
 
   let description = [{
@@ -60,7 +60,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
 }
 
 //===----------------------------------------------------------------------===//
-// XLA binary elementwise op definitions.
+// CHLO binary elementwise op definitions.
 // From the client perspective, each of these support both explicit rank
 // broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
 // shape broadcasting.

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// This file defines the operations used in the XLA dialect.
+// This file defines the operations used in the MHLO dialect.
 
 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_

@@ -13,10 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// This is the operation definition file for XLA HLO ops which map to the
-// traditional definition in xla_data.proto (or are aligned with the goals
-// thereof).
-// See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto
+// This is the operation definition file for MHLO ops.
 
 #ifndef HLO_OPS
 #define HLO_OPS
@@ -44,7 +41,7 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
 }
 
 //===----------------------------------------------------------------------===//
-// XLA nullary op definitions.
+// MHLO nullary op definitions.
 //===----------------------------------------------------------------------===//
 
 def HLO_ConstOp : HLO_Op<"constant",
@@ -113,7 +110,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
 }
 
 //===----------------------------------------------------------------------===//
-// XLA unary elementwise op definitions.
+// MHLO unary elementwise op definitions.
 //===----------------------------------------------------------------------===//
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
 
@@ -264,7 +261,7 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
     HLO_FpOrComplexTensor>, BASE_HLO_TanhOp;
 
 //===----------------------------------------------------------------------===//
-// XLA binary elementwise op definitions.
+// MHLO binary elementwise op definitions.
 //===----------------------------------------------------------------------===//
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
 
@@ -363,7 +360,7 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
 }
 
 //===----------------------------------------------------------------------===//
-// XLA binary logical elementwise op definitions.
+// MHLO binary logical elementwise op definitions.
 //===----------------------------------------------------------------------===//
 
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@@ -381,7 +378,7 @@ def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp;
 def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
 
 //===----------------------------------------------------------------------===//
-// XLA communication op definitions.
+// MHLO communication op definitions.
 //===----------------------------------------------------------------------===//
 
 // InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
@@ -481,7 +478,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
 }
 
 //===----------------------------------------------------------------------===//
-// XLA parallelism related op definitions.
+// MHLO parallelism related op definitions.
 //===----------------------------------------------------------------------===//
 
 def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
@@ -492,7 +489,7 @@ def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
 }
 
 //===----------------------------------------------------------------------===//
-// XLA control flow op definitions.
+// MHLO control flow op definitions.
 //===----------------------------------------------------------------------===//
 
 def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> {
@@ -640,7 +637,7 @@ def HLO_ReduceOp: HLO_Op<"reduce", [
 }
 
 //===----------------------------------------------------------------------===//
-// XLA tuple op definitions.
+// MHLO tuple op definitions.
 //===----------------------------------------------------------------------===//
 def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO_GetTupleElementOp {
   let arguments = (ins
@@ -684,7 +681,7 @@ def HLO_CompareOp: HLO_Op<"compare",
 }
 
 //===----------------------------------------------------------------------===//
-// XLA Slice definitions.
+// MHLO Slice definitions.
 //===----------------------------------------------------------------------===//
 
 def HLO_SliceOp: HLO_Op<
@@ -745,7 +742,7 @@ def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice",
 
 
 //===----------------------------------------------------------------------===//
-// XLA Other op definitions.
+// MHLO Other op definitions.
 //===----------------------------------------------------------------------===//
 
 def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>,
@@ -1320,7 +1317,7 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
 }
 
 //===----------------------------------------------------------------------===//
-// XLA RngUniform Operator.
+// MHLO RngUniform Operator.
 //===----------------------------------------------------------------------===//
 def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
   let arguments = (ins
@@ -1347,7 +1344,7 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
 }
 
 //===----------------------------------------------------------------------===//
-// XLA Quantize Operator.
+// MHLO Quantize Operator.
 //===----------------------------------------------------------------------===//
 def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>,
     BASE_HLO_DequantizeOp {

@@ -35,7 +35,7 @@ def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
 defvar BroadcastDimAttr = I64ElementsAttr;
 
 //===----------------------------------------------------------------------===//
-// XLA on tensors type definitions.
+// MHLO on tensors type definitions.
 //===----------------------------------------------------------------------===//
 
 // Token type.
@@ -78,7 +78,7 @@ def HLO_StaticShapeTensor : StaticShapeTensorOf<[
     AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
 
 //===----------------------------------------------------------------------===//
-// XLA on tensors combined type definitions.
+// MHLO on tensors combined type definitions.
 //===----------------------------------------------------------------------===//
 
 // Any integer or floating-point tensor types
@@ -97,7 +97,7 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
 def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
 
 //===----------------------------------------------------------------------===//
-// XLA nullary op definitions.
+// MHLO nullary op definitions.
 //===----------------------------------------------------------------------===//
 
 class BASE_HLO_ConstOp {
@@ -117,7 +117,7 @@ class BASE_HLO_IotaOp {
 }
 
 //===----------------------------------------------------------------------===//
-// XLA unary elementwise op definitions.
+// MHLO unary elementwise op definitions.
 //===----------------------------------------------------------------------===//
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
 
@@ -25,19 +25,19 @@ def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
 def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
 
 class ConstantSplat<string value> : NativeCodeCall<
-    "xla::getSplat(&$_builder, $0, " # value # ")">;
+    "hlo::getSplat(&$_builder, $0, " # value # ")">;
 
 def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
 
 def BinBroadcastDimensions : NativeCodeCall<
-    "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
+    "hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
 
 def BinBroadcastDimensionsNonEmpty : NativeCodeCall<
-    "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">;
+    "hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">;
 
 // Here, the element type can be any integer or float type. But, note that only
 // 32 bit integers are supported for the value.
 class GetScalarOfType<int value> : NativeCodeCall<
-    "xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
+    "hlo::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
 
 #endif // HLO_UTILS

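For orientation: these NativeCodeCall helpers splice C++ into the TableGen-generated rewrite patterns, so the change above only retargets the emitted calls from xla:: to hlo::. A minimal sketch of the C++ such a pattern ends up invoking after the rename (the operand value here is a placeholder, not part of this commit):

    // Sketch only: roughly what GetScalarOfType<0> expands to inside a
    // generated matchAndRewrite(). hlo::GetScalarOfType is declared in
    // hlo_utils.h (see that file's diff further down).
    mlir::DenseElementsAttr zero_attr =
        mlir::hlo::GetScalarOfType(mlir::getElementTypeOrSelf(operand), 0);
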
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// This is the operation definition file for LXLA.
+// This is the operation definition file for LMHLO, the "late" MHLO variant of
+// the dialect, which operates on buffers instead of tensors.
 //
-// This file largely overlaps with hlo_ops.td at a logic level. It's tempting to
+// This file largely overlaps with mhlo_ops.td at a logic level. It's tempting to
 // merge these two files together, but we need to consider the following
 // obstacles:
 // * We need to have a common representation for arguments. That is to say,
@@ -43,7 +44,7 @@ def LHLO_Dialect : Dialect {
 }
 
 //===----------------------------------------------------------------------===//
-// XLA type definitions.
+// LMHLO type definitions.
 //===----------------------------------------------------------------------===//
 
 // Any integer tensor types
@@ -66,7 +67,7 @@ def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
 def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
 
 //===----------------------------------------------------------------------===//
-// XLA nullary op definitions.
+// LMHLO nullary op definitions.
 //===----------------------------------------------------------------------===//
 
 class LHLO_Op<string mnemonic, list<OpTrait> traits> :
@@ -86,7 +87,7 @@ def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
 }
 
 //===----------------------------------------------------------------------===//
-// XLA unary elementwise op definitions.
+// LMHLO unary elementwise op definitions.
 //===----------------------------------------------------------------------===//
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
 
@@ -157,7 +158,7 @@ def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HL
 def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp;
 
 //===----------------------------------------------------------------------===//
-// XLA binary elementwise op definitions.
+// LMHLO binary elementwise op definitions.
 //===----------------------------------------------------------------------===//
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
 
@@ -212,7 +213,7 @@ def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp;
 def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp;
 
 //===----------------------------------------------------------------------===//
-// XLA control flow op definitions.
+// LMHLO control flow op definitions.
 //===----------------------------------------------------------------------===//
 
 // TODO(b/139813999): specify required function signature in a type-safe way.
@@ -284,7 +285,7 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
 }
 
 //===----------------------------------------------------------------------===//
-// XLA tuple op definitions.
+// LMHLO tuple op definitions.
 //===----------------------------------------------------------------------===//
 
 def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
@@ -298,7 +299,7 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
 }
 
 //===----------------------------------------------------------------------===//
-// XLA Slice definitions.
+// LMHLO Slice definitions.
 //===----------------------------------------------------------------------===//
 
 def LHLO_SliceOp: LHLO_Op<
@@ -483,7 +484,7 @@ def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
 
 
 //===----------------------------------------------------------------------===//
-// XLA Other op definitions.
+// LMHLO Other op definitions.
 //===----------------------------------------------------------------------===//
 
 def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
-#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
 
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
@@ -150,15 +150,14 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
 }
 
 template <typename PredicateType>
-inline Optional<PredicateType> getCmpPredicate(
-    StringRef xla_comparison_direction) {
+inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
   return llvm::None;
 }
 
 template <>
 inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
-    StringRef xla_comparison_direction) {
-  return llvm::StringSwitch<Optional<CmpFPredicate>>(xla_comparison_direction)
+    StringRef comparison_direction) {
+  return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
       .Case("EQ", CmpFPredicate::OEQ)
       .Case("NE", CmpFPredicate::ONE)
      .Case("GE", CmpFPredicate::OGE)
@@ -170,8 +169,8 @@ inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
 
 template <>
 inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
-    StringRef xla_comparison_direction) {
-  return llvm::StringSwitch<Optional<CmpIPredicate>>(xla_comparison_direction)
+    StringRef comparison_direction) {
+  return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
      .Case("EQ", CmpIPredicate::eq)
      .Case("NE", CmpIPredicate::ne)
      .Case("GE", CmpIPredicate::sge)
@@ -181,11 +180,11 @@ inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
      .Default(llvm::None);
 }
 
-template <typename XLACompareOpTy>
-inline Value MapXlaCompareOpToStdScalarOp(Location loc,
-                                          StringRef comparison_direction,
-                                          ArrayRef<Type> result_types,
-                                          ArrayRef<Value> args, OpBuilder* b) {
+template <typename CompareOpTy>
+inline Value MapCompareOpToStdScalarOp(Location loc,
+                                       StringRef comparison_direction,
+                                       ArrayRef<Type> result_types,
+                                       ArrayRef<Value> args, OpBuilder* b) {
   const auto& lhs = args[0];
   const auto& rhs = args[1];
   Type element_type = lhs.getType();
@@ -193,15 +192,15 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc,
     Optional<CmpIPredicate> predicate =
        getCmpPredicate<CmpIPredicate>(comparison_direction);
     assert(predicate.hasValue() && "expected valid comparison direction");
-    return b->create<ScalarIOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
-                                                rhs);
+    return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
+                                             rhs);
   }
   if (element_type.isa<FloatType>()) {
     Optional<CmpFPredicate> predicate =
        getCmpPredicate<CmpFPredicate>(comparison_direction);
     assert(predicate.hasValue() && "expected valid comparison direction");
-    return b->create<ScalarFOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
-                                                rhs);
+    return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
+                                             rhs);
   }
   return nullptr;
 }
@@ -337,10 +336,10 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
       loc, result_types, args, b);
 }
 
-/// Implements the conversion of XLA op to scalar op (to use within region of a
+/// Implements the conversion of HLO op to scalar op (to use within region of a
 /// linalg.generic op) for compare-select style operations like min/max.
 template <typename... Args>
-struct XlaCompareSelectOpToStdScalarOp {
+struct CompareSelectOpToStdScalarOp {
   static Value map(Location loc, StringRef comparison_direction,
                    ArrayRef<Type> result_types, ArrayRef<Value> args,
                    OpBuilder* b) {
@@ -352,8 +351,8 @@ struct XlaCompareSelectOpToStdScalarOp {
 /// dialect with a given predicate based on the element type of the operand.
 template <typename SupportedType, typename StdCompareOp, typename Predicate,
           typename... Args>
-struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
-                                       Args...> {
+struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
+                                    Args...> {
   static Value map(Location loc, StringRef comparison_direction,
                    ArrayRef<Type> result_types, ArrayRef<Value> args,
                    OpBuilder* b) {
@@ -365,8 +364,8 @@ struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
                                           args[0], args[1]);
       return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
     }
-    return XlaCompareSelectOpToStdScalarOp<Args...>::map(
-        loc, comparison_direction, result_types, args, b);
+    return CompareSelectOpToStdScalarOp<Args...>::map(loc, comparison_direction,
+                                                      result_types, args, b);
   }
 };
 
@@ -384,7 +383,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
                                                   ArrayRef<Type> result_types,
                                                   ArrayRef<Value> args,
                                                   OpBuilder* b) {
-  return XlaCompareSelectOpToStdScalarOp<
+  return CompareSelectOpToStdScalarOp<
      IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
      ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types,
                                                       args, b);
@@ -395,7 +394,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
                                                   ArrayRef<Type> result_types,
                                                   ArrayRef<Value> args,
                                                   OpBuilder* b) {
-  return XlaCompareSelectOpToStdScalarOp<
+  return CompareSelectOpToStdScalarOp<
      IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
      ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types,
                                                       args, b);
@@ -475,25 +474,25 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
 
 }  // namespace impl
 
-struct XlaOpToStdScalarOp {
+struct HloOpToStdScalarOp {
   // Implementation for LHLO ops except lmhlo::CompareOp.
-  template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
+  template <typename HloOpTy, typename LhloOpTy = HloOpTy,
            typename = std::enable_if_t<
                !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
                std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
                             std::false_type>::value>>
-  static Value map(XlaOpTy op, ArrayRef<Type> result_types,
+  static Value map(HloOpTy op, ArrayRef<Type> result_types,
                    ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
                                                  args, b);
  }
 
  // Implementation for HLO ops except mhlo::CompareOp.
-  template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
+  template <typename HloOpTy, typename LhloOpTy = mhlo::HloToLhloOp<HloOpTy>,
           typename = std::enable_if_t<
               !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
               !std::is_same<LhloOpTy, std::false_type>::value>>
-  static Value map(XlaOpTy op, ArrayRef<Type> result_types,
+  static Value map(HloOpTy op, ArrayRef<Type> result_types,
                    ArrayRef<Value> args, OpBuilder* b, int i = 0) {
    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
                                                  args, b);
@@ -505,7 +504,7 @@ struct XlaOpToStdScalarOp {
   static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types,
                    ArrayRef<Value> args, OpBuilder* b) {
     auto comparison_direction = op.comparison_direction();
-    return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
+    return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
        op.getLoc(), comparison_direction, result_types, args, b);
   }
 
@@ -516,7 +515,7 @@ struct XlaOpToStdScalarOp {
   static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
                    ArrayRef<Value> args, OpBuilder* b) {
     auto comparison_direction = op.comparison_direction();
-    return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
+    return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
        op.getLoc(), comparison_direction, result_types, args, b);
   }
 };
@@ -524,4 +523,4 @@ struct XlaOpToStdScalarOp {
 }  // namespace lmhlo
 }  // namespace mlir
 
-#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_

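As a caller-side sketch of the renamed entry point (the wrapper function below is hypothetical; the map<> signature is taken from the header above):

    #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"

    // Lower an lmhlo.add to the matching standard-dialect scalar op. This was
    // formerly spelled lmhlo::XlaOpToStdScalarOp::map; after this commit it is
    // lmhlo::HloOpToStdScalarOp::map.
    mlir::Value LowerAddToScalar(mlir::lmhlo::AddOp op,
                                 llvm::ArrayRef<mlir::Type> result_types,
                                 llvm::ArrayRef<mlir::Value> args,
                                 mlir::OpBuilder* b) {
      return mlir::lmhlo::HloOpToStdScalarOp::map<mlir::lmhlo::AddOp>(
          op, result_types, args, b);
    }
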
@@ -56,7 +56,7 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
 std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
 
 // fuse mhlo ops to kLoop/kInput fusion patterns
-std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
+std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
 
 }  // namespace mhlo
 
@@ -94,12 +94,12 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
 
 }  // namespace lmhlo
 
-namespace xla {
+namespace hlo {
 
 /// Lowers the standard TanhOp to an approximation that does not use intrinsics.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
 
-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
 
 #endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_

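A short sketch of wiring the renamed fusion pass into a pipeline. The pass-manager calls are the standard mlir::PassManager API of this era and the surrounding module handling is assumed, not part of this commit:

    // Run the renamed per-function fusion pass over a module.
    mlir::PassManager pm(module.getContext());
    pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createMhloFusionPass());
    if (mlir::failed(pm.run(module))) {
      // Fusion pipeline failed; surface the error to the caller.
    }
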
@@ -38,8 +38,8 @@ void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
 void PopulateComplexLoweringPatterns(MLIRContext *context,
                                      OwningRewritePatternList *patterns);
 
-void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
-                              MLIRContext *ctx);
+void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
+                               MLIRContext *ctx);
 
 // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
 void populateHLOToLHLOConversionPattern(
@@ -93,14 +93,14 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
 
 }  // namespace chlo
 
-namespace xla {
+namespace hlo {
 
 // Populates a pattern that translates the standard TanhOp to an approximation
 // that does not use intrinsics.
 void PopulateTanhToApproximationPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns);
 
-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
 
 #endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_
 
 // Utilities relating to implementing HLO broadcasting.
 // Note: This file should not depend on any non-MLIR TensorFlow libraries.
@@ -27,7 +27,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
 
 namespace mlir {
-namespace xla {
+namespace hlo {
 
 // Checks whether the given operand types and broadcast_dims attr represent a
 // legal combination for "numpy" style broadcasting (where 1-dims are prepended
@@ -43,7 +43,7 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
                                                         Value rhs,
                                                         OpBuilder& builder);
 
-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
 
-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_

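For callers of these utilities the update is a one-line namespace change. A hedged sketch mirroring the call sites elsewhere in this commit (lhs, rhs, broadcast_dimensions, loc, and builder are placeholders):

    // Formerly xla::IsLegalNumpyRankedBroadcast and
    // xla::ComputeBinaryElementwiseBroadcastingResultExtents; now hlo::.
    if (!mlir::hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions))
      return mlir::failure();
    mlir::Value extents =
        mlir::hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
            loc, lhs, rhs, builder);
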
@@ -13,21 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
 
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
 
 namespace mlir {
-namespace xla {
+namespace hlo {
 
 // Converts the given elements attr to the specified elements type.
 // Requires type of the elements and new_type to be either integer or float
 // type.
 mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
                                        mlir::Type new_type);
-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
 
-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_
 
 #include <vector>
 
@@ -162,4 +162,4 @@ class GraphCycles {
 
 }  // namespace mlir
 
-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
 
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
@@ -23,7 +23,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
 
 namespace mlir {
-namespace xla {
+namespace hlo {
 
 // Computes the broadcast dimensions attr for an elementwise binary operator
 // between two ranked tensors.
@@ -68,7 +68,7 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
 // Requires `ty` to be either FloatType of IntegerType.
 DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
 
-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
 
-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_

@@ -137,7 +137,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
   auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
                                   .dyn_cast_or_null<DenseIntElementsAttr>();
   if (broadcast_dimensions &&
-      !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
+      !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
     // Note: It is unclear whether the general specification of explicit
     // broadcast_dimensions on binary ops is a feature we want to carry
     // forward. While it can technically be implemented for ranked-dynamic,
@@ -150,7 +150,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
            << "broadcast_dimensions = " << broadcast_dimensions;
   }
 
-  Value computed_shape = xla::ComputeBinaryElementwiseBroadcastingResultExtents(
+  Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
       loc, lhs, rhs, builder);
   if (!computed_shape) return failure();
   reifiedReturnShapes.push_back(computed_shape);

@@ -17,7 +17,7 @@ limitations under the License.
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 
-// Static initialization for XLA dialect registration.
+// Static initialization for *HLO dialects registration.
 static mlir::DialectRegistration<mlir::mhlo::MhloDialect> mhlo_ops;
 static mlir::DialectRegistration<mlir::chlo::HloClientDialect> chlo_ops;
 static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops;

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// This file defines the operations used in the XLA dialect.
+// This file defines the operations used in the MHLO dialect.
 
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 
@@ -404,7 +404,7 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
 
   // If the operand is constant, we can do the conversion now.
   if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
-    return xla::ConvertElementsAttr(elementsAttr,
+    return hlo::ConvertElementsAttr(elementsAttr,
                                     getElementTypeOrSelf(getResult()));
   }
 
@@ -2135,8 +2135,6 @@ MhloDialect::MhloDialect(MLIRContext* context)
           >();
   addInterfaces<HLOInlinerInterface>();
   addTypes<TokenType>();
-  // Support unknown operations because not all XLA operations are registered.
-  // allowUnknownOperations();
 }
 
 Type MhloDialect::parseType(DialectAsmParser& parser) const {

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// This file defines the operations used in the XLA dialect.
+// This file defines the operations used in the LMHLO dialect.
 
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 

@@ -96,7 +96,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
     // Check for "numpy"-style rank broadcast.
     auto broadcast_dimensions = op.broadcast_dimensions();
     if (broadcast_dimensions &&
-        !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
+        !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
       // Note: It is unclear whether the general specification of explicit
       // broadcast_dimensions on binary ops is a feature we want to carry
      // forward. While it can technically be implemented for ranked-dynamic,
@@ -126,7 +126,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
 
     int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
     Value result_extents =
-        xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
+        hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
                                                                rewriter);
 
     // Note that we unconditionally emit DynamicBroadcastInDim ops and let

@@ -53,5 +53,5 @@ struct TestChloLegalizeToHloPass
 }  // namespace mlir
 
 static mlir::PassRegistration<mlir::chlo::TestChloLegalizeToHloPass> pass(
-    "test-xla-chlo-legalize-to-hlo",
+    "mhlo-test-chlo-legalize-to-hlo",
     "Test pass for applying chlo -> hlo legalization patterns");

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// This file implements logic for lowering XLA dialect to Standard dialect.
+// This file implements logic for lowering MHLO dialect to Standard dialect.
 
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
@@ -107,8 +107,8 @@ LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
 }
 
 LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
-  // Converts an XLA while loop into control flow. This generates a set of MLIR
-  // blocks and branches, along with inlining the regions provided by the XLA
+  // Converts a MHLO while loop into control flow. This generates a set of MLIR
+  // blocks and branches, along with inlining the regions provided by the MHLO
   // while loop. The structure should be similar to below:
   //
   // <prior operations>
@@ -232,5 +232,5 @@ mlir::mhlo::createLegalizeControlFlowPass() {
 }
 
 static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass(
-    "xla-legalize-control-flow",
-    "Legalize from XLA control flow to MLIR control flow");
+    "mhlo-legalize-control-flow",
+    "Legalize from MHLO control flow to CFG control flow");

@@ -24,7 +24,7 @@ limitations under the License.
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 
 namespace mlir {
-namespace xla {
+namespace hlo {
 namespace {
 
 /// Emits the fast tanh approximation that is also used by XLA.
@@ -149,8 +149,8 @@ void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
 }
 
 static PassRegistration<LegalizeTanhToApproximation> legalize_pass(
-    "xla-legalize-tanh-to-approximation",
+    "mhlo-legalize-tanh-to-approximation",
     "Legalize tanh from standard dialect to an approximation");
 
-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir

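A hedged sketch of using the renamed hlo:: entry point from a custom pass (the pass skeleton is hypothetical; the call mirrors the rewriters.h declaration and the LowerComplex::runOnFunction pattern later in this commit):

    // Hypothetical function pass body that applies the tanh approximation
    // patterns, now populated from the hlo:: namespace instead of xla::.
    void MyLegalizeTanh::runOnFunction() {
      mlir::OwningRewritePatternList patterns;
      mlir::hlo::PopulateTanhToApproximationPatterns(&getContext(), &patterns);
      applyPatternsAndFoldGreedily(getFunction(), patterns);
    }
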
@@ -32,7 +32,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
-#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
+#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 
 namespace mlir {
@@ -49,12 +49,12 @@ Value getResultValue(Operation* op) {
 }
 
 template <bool isLHLO = true>
-ShapedType getXLAOpResultType(Operation* op) {
+ShapedType getHloOpResultType(Operation* op) {
   return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
 }
 
 template <bool isLHLO = true>
-bool verifyXLAOpBufferOrTensorSemantics(Operation* op) {
+bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
   auto verifyType = [&](Value val) -> bool {
     return (isLHLO && val.getType().isa<MemRefType>()) ||
            (!isLHLO && val.getType().isa<RankedTensorType>());
@@ -133,7 +133,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
         [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
           // TODO(ravishankarm) : For now use the method in lmhlo namespace.
           // That method needs to be moved out of there.
-          Value opResult = lmhlo::XlaOpToStdScalarOp::map<OpTy>(
+          Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
              op, bodyResultTypes,
              llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
           nestedBuilder.create<linalg::YieldOp>(loc, opResult);
@@ -163,7 +163,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
     auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
     auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
     // TODO(ravishankarm) : Move this method out of lmhlo namespace.
-    Value opResult = lmhlo::XlaOpToStdScalarOp::map<LhloOp>(
+    Value opResult = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
        lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
        &rewriter);
     rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
@@ -274,7 +274,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
   }
 };
 
-/// Base class for lowering xla operations that have one operand and one result,
+/// Base class for lowering HLO operations that have one operand and one result,
 /// and are semantically equivalent to a copy of the input to the output (like
 /// transpose, some reshape, etc.). The derived classes need to provide a method
 /// `getIndexingMaps` that returns AffineMaps for the index maps of the input
@@ -287,8 +287,8 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
   LogicalResult matchAndRewrite(
       OpTy op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
-    auto resultType = getXLAOpResultType<isLHLO>(op);
+    if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
+    auto resultType = getHloOpResultType<isLHLO>(op);
 
     SmallVector<AffineMap, 2> indexing_maps =
         Derived::getIndexingMaps(op, &rewriter);
@@ -322,7 +322,7 @@ class BroadcastConverter
     ShapedType inputType =
         broadcastOp.operand().getType().template cast<ShapedType>();
     unsigned inputRank = inputType.getRank();
-    unsigned nloops = getXLAOpResultType<isLHLO>(broadcastOp).getRank();
+    unsigned nloops = getHloOpResultType<isLHLO>(broadcastOp).getRank();
 
     // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
     // the input's dimensions.
@@ -356,7 +356,7 @@ class HloBroadcastInDimConverter
 
   static SmallVector<AffineMap, 2> getIndexingMaps(
       mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
-    auto resultType = getXLAOpResultType<false>(broadcastOp);
+    auto resultType = getHloOpResultType<false>(broadcastOp);
     auto operandType =
         broadcastOp.operand().getType().template cast<ShapedType>();
     unsigned nloops = resultType.getRank();
@@ -555,7 +555,7 @@ class TransposeConverter
                                  isLHLO>::DataMovementOpConverter;
   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
     auto resultType =
-        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
+        getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
     auto nloops = resultType.getRank();
     SmallVector<AffineExpr, 2> inputExprs;
     inputExprs.resize(resultType.getRank());
@@ -579,11 +579,11 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
   LogicalResult matchAndRewrite(
       OpTy reshapeOp, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
+    if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
       return failure();
     ShapedType operandType =
         reshapeOp.operand().getType().template cast<ShapedType>();
-    ShapedType resultType = getXLAOpResultType<isLHLO>(reshapeOp);
+    ShapedType resultType = getHloOpResultType<isLHLO>(reshapeOp);
 
     if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
       return failure();
@@ -708,7 +708,7 @@ class ReverseConverter
                                  isLHLO>::DataMovementOpConverter;
   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
     auto resultType =
-        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
+        getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
     auto nloops = resultType.getRank();
     SmallVector<AffineExpr, 2> inputExprs;
     inputExprs.reserve(nloops);

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// This file implements logic for lowering XLA dialect to Standard dialect.
+// This file implements logic for lowering MHLO dialect to Standard dialect.
 
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
@@ -187,8 +187,8 @@ std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
   return std::make_unique<LegalizeToStandard>();
 }
 
-void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
-                              mlir::MLIRContext *ctx) {
+void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
+                               mlir::MLIRContext *ctx) {
   mlir::populateWithGenerated(ctx, patterns);
   patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
 }
@@ -196,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
 /// Perform the lowering to standard dialect.
 void LegalizeToStandard::runOnFunction() {
   OwningRewritePatternList patterns;
-  mlir::mhlo::PopulateXlaToStdPatterns(&patterns, &getContext());
+  mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
   applyPatternsAndFoldGreedily(getFunction(), patterns);
 }
 
 static PassRegistration<LegalizeToStandard> legalize_pass(
-    "xla-legalize-to-std", "Legalize from XLA dialect to standard dialect");
+    "mhlo-legalize-to-std", "Legalize from MHLO dialect to standard dialect");
 
 }  // end namespace mhlo
 }  // end namespace mlir

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// This is the legalization pattern definition file for XLA to StandardOps.
+// This is the legalization pattern definition file for MHLO to StandardOps.
 
 include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
 include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td"

@@ -25,7 +25,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
-#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
+#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
 
 namespace mlir {
 namespace lmhlo {
@@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
     auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
     auto result =
         rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
-    Value op_result = lmhlo::XlaOpToStdScalarOp::map<DotOp>(
+    Value op_result = lmhlo::HloOpToStdScalarOp::map<DotOp>(
        op, element_type, {l, r, result}, &builder);
     map_status = success(op_result != nullptr);
     if (failed(map_status)) return;
@@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
                      ValueRange induction_vars) {
     auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
     auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
-    Value op_result = lmhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
+    Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOpTy>(
        op, element_type, {l, r}, &builder);
     map_status = success(op_result != nullptr);
     if (failed(map_status)) return;

@ -35,7 +35,7 @@ limitations under the License.
|
|||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
|
|
|
@ -192,22 +192,22 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
|||
using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
lmhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
|
||||
lmhlo::ReduceOp reduce_op, ArrayRef<Value> /*args*/,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
// TODO(b/137624192) Implement variadic reduce.
|
||||
if (xla_reduce_op.out().size() != 1) return failure();
|
||||
if (reduce_op.out().size() != 1) return failure();
|
||||
|
||||
scf::ReduceOp reduce_op =
|
||||
CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter);
|
||||
ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op,
|
||||
&xla_reduce_op.body().front(), &rewriter);
|
||||
rewriter.replaceOp(xla_reduce_op, llvm::None);
|
||||
scf::ReduceOp scf_reduce_op =
|
||||
CreateReduceOpInNestedParallelLoops(reduce_op, &rewriter);
|
||||
ConvertToReductionOperator(reduce_op.getLoc(), scf_reduce_op,
|
||||
&reduce_op.body().front(), &rewriter);
|
||||
rewriter.replaceOp(reduce_op, llvm::None);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp
|
||||
// refers to the parallel dimensions of `xla_reduce_op` if any and the inner
|
||||
// refers to the parallel dimensions of `reduce_op` if any and the inner
|
||||
// ParallelOp refers to the reduction dimensions. The scf.reduce op is
|
||||
// returned.
|
||||
//
|
||||
|
@ -226,16 +226,15 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
|||
// scf.yield
|
||||
// }
|
||||
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
|
||||
lmhlo::ReduceOp xla_reduce_op,
|
||||
ConversionPatternRewriter* rewriter) const {
|
||||
auto loc = xla_reduce_op.getLoc();
|
||||
lmhlo::ReduceOp reduce_op, ConversionPatternRewriter* rewriter) const {
|
||||
auto loc = reduce_op.getLoc();
|
||||
DenseSet<int> reducing_dims;
|
||||
for (const auto& rdim : xla_reduce_op.dimensions().getIntValues()) {
|
||||
for (const auto& rdim : reduce_op.dimensions().getIntValues()) {
|
||||
reducing_dims.insert(rdim.getSExtValue());
|
||||
}
|
||||
|
||||
Value operand = *xla_reduce_op.operands().begin();
|
||||
Value out = *xla_reduce_op.out().begin();
|
||||
Value operand = *reduce_op.operands().begin();
|
||||
Value out = *reduce_op.out().begin();
|
||||
SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
|
||||
SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
|
||||
auto operand_shape = operand.getType().cast<MemRefType>().getShape();
|
||||
|
@ -252,7 +251,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
|||
}
|
||||
// Load initial value from memref<element_type>.
|
||||
SmallVector<Value, 1> init_value = {
|
||||
rewriter->create<LoadOp>(loc, *xla_reduce_op.init_values().begin())};
|
||||
rewriter->create<LoadOp>(loc, *reduce_op.init_values().begin())};
|
||||
// Outer ParallelOp is not needed if it is a reduction across all dims.
|
||||
scf::ParallelOp outer;
|
||||
if (!parallel_lower.empty()) {
|
||||
|
@ -293,7 +292,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
|||
|
||||
rewriter->setInsertionPointToStart(inner.getBody());
|
||||
Value elem = rewriter->create<mlir::LoadOp>(
|
||||
loc, *xla_reduce_op.operands().begin(), indices);
|
||||
loc, *reduce_op.operands().begin(), indices);
|
||||
return rewriter->create<scf::ReduceOp>(loc, elem);
|
||||
}
|
||||
};
|
||||
|
@ -364,42 +363,42 @@ class ReduceWindowOpConverter
|
|||
using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
lmhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
|
||||
lmhlo::ReduceWindowOp reduce_window_op, ArrayRef<Value> /*args*/,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
scf::ParallelOp output_loop, window_loop;
|
||||
std::tie(output_loop, window_loop) =
|
||||
CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op,
|
||||
CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op,
|
||||
&rewriter);
|
||||
|
||||
scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
|
||||
xla_reduce_window_op, output_loop, window_loop, &rewriter);
|
||||
reduce_window_op, output_loop, window_loop, &rewriter);
|
||||
|
||||
ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op,
|
||||
&xla_reduce_window_op.body().front(), &rewriter);
|
||||
rewriter.replaceOp(xla_reduce_window_op, llvm::None);
|
||||
ConvertToReductionOperator(reduce_window_op.getLoc(), reduce_op,
|
||||
&reduce_window_op.body().front(), &rewriter);
|
||||
rewriter.replaceOp(reduce_window_op, llvm::None);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
std::pair<scf::ParallelOp, scf::ParallelOp>
|
||||
CreateParallelLoopsToTraverseOutputAndWindow(
|
||||
lmhlo::ReduceWindowOp xla_reduce_window_op,
|
||||
lmhlo::ReduceWindowOp reduce_window_op,
|
||||
ConversionPatternRewriter* rewriter) const {
|
||||
auto loc = xla_reduce_window_op.getLoc();
|
||||
auto loc = reduce_window_op.getLoc();
|
||||
Value init_value =
|
||||
rewriter->create<LoadOp>(loc, xla_reduce_window_op.init_value());
|
||||
rewriter->create<LoadOp>(loc, reduce_window_op.init_value());
|
||||
|
||||
Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter->create<ConstantIndexOp>(loc, 1);
|
||||
|
||||
// Create an outer parallel loop that spans the output of ReduceWindowOp.
|
||||
Value xla_output = xla_reduce_window_op.out();
|
||||
auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter);
|
||||
Value output = reduce_window_op.out();
|
||||
auto output_loop = MakeLoopOverShape(loc, output, rewriter);
|
||||
|
||||
// Create a nested loop that traverses the window.
|
||||
SmallVector<Value, 2> window_lower, window_upper, window_step;
|
||||
rewriter->setInsertionPointToStart(output_loop.getBody());
|
||||
for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) {
|
||||
for (const auto& window_dim : reduce_window_op.window_dimensions()) {
|
||||
window_step.push_back(one);
|
||||
window_lower.push_back(zero);
|
||||
window_upper.push_back(
|
||||
|
@@ -410,38 +409,38 @@ class ReduceWindowOpConverter

     Value reduction_result = *window_loop.getResults().begin();
     auto output_ivs = output_loop.getInductionVars();
-    rewriter->create<StoreOp>(loc, reduction_result, xla_output, output_ivs);
+    rewriter->create<StoreOp>(loc, reduction_result, output, output_ivs);
     return std::make_pair(output_loop, window_loop);
   }

   scf::ReduceOp CreateReduceOpInNestedParallelLoops(
-      lmhlo::ReduceWindowOp xla_reduce_window_op, scf::ParallelOp output_loop,
+      lmhlo::ReduceWindowOp reduce_window_op, scf::ParallelOp output_loop,
       scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
     rewriter->setInsertionPointToStart(window_loop.getBody());
-    auto loc = xla_reduce_window_op.getLoc();
+    auto loc = reduce_window_op.getLoc();

-    if (xla_reduce_window_op.base_dilations().hasValue() ||
-        xla_reduce_window_op.window_dilations().hasValue()) {
-      xla_reduce_window_op.emitRemark(
+    if (reduce_window_op.base_dilations().hasValue() ||
+        reduce_window_op.window_dilations().hasValue()) {
+      reduce_window_op.emitRemark(
           "Lowering to parallel loops does not support `base_dilations` or "
           "`window_dilations` attributes yet. The attributes will be ignored.");
     }

-    Value xla_operand = xla_reduce_window_op.operand();
-    auto xla_operand_type = xla_operand.getType().cast<MemRefType>();
+    Value operand = reduce_window_op.operand();
+    auto operand_type = operand.getType().cast<MemRefType>();

     // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
-    MappedIvs mapped_ivs = MapWindowIvsToInput(
-        xla_reduce_window_op, output_loop.getInductionVars(),
-        window_loop.getInductionVars(), rewriter);
+    MappedIvs mapped_ivs =
+        MapWindowIvsToInput(reduce_window_op, output_loop.getInductionVars(),
+                            window_loop.getInductionVars(), rewriter);

     auto elem_or_init = rewriter->create<scf::IfOp>(
-        loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds,
+        loc, operand_type.getElementType(), mapped_ivs.in_bounds,
         /*withElseRegion=*/true);

     OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
     Value elem = then_builder.create<mlir::LoadOp>(
-        loc, xla_reduce_window_op.operand(), mapped_ivs.ivs);
+        loc, reduce_window_op.operand(), mapped_ivs.ivs);
     then_builder.create<scf::YieldOp>(loc, elem);

     OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
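The two hunks above are a pure rename (xla_reduce_window_op to reduce_window_op); the lowering logic is unchanged. For orientation, here is a minimal sketch of how a converter like ReduceWindowOpConverter is typically driven. The pass wrapper name and the legal/illegal target choices below are illustrative assumptions, not part of this commit:

// Illustrative driver (assumed names), not part of this commit: register the
// converter above and run a partial conversion so that lmhlo.reduce_window is
// rewritten into nested scf.parallel loops.
struct LhloLegalizeToParallelLoopsSketch
    : public PassWrapper<LhloLegalizeToParallelLoopsSketch, FunctionPass> {
  void runOnFunction() override {
    FuncOp func = getFunction();
    OwningRewritePatternList patterns;
    patterns.insert<ReduceWindowOpConverter>(func.getContext());

    ConversionTarget target(getContext());
    // Loop and standard ops produced by the converter stay legal; the lmhlo
    // op being lowered must be marked illegal for the conversion to fire.
    target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
    target.addIllegalOp<lmhlo::ReduceWindowOp>();

    if (failed(applyPartialConversion(func, target, patterns)))
      signalPassFailure();
  }
};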
@@ -45,13 +45,13 @@ class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
  public:
   explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {}

-  /// Performs the lowering to XLA dialect.
+  /// Performs the lowering to MHLO dialect.
   void runOnFunction() override;
 };
 }  // end anonymous namespace

 namespace mlir {
-namespace xla {
+namespace hlo {
 namespace {

 #include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc"

@@ -62,18 +62,18 @@ void PopulateComplexLoweringPatterns(MLIRContext* context,
                                      OwningRewritePatternList* patterns) {
   populateWithGenerated(context, patterns);
 }
-}  // end namespace xla
+}  // end namespace hlo
 }  // end namespace mlir

 // Lowers the complex operations that can be represented using other operations.
 void LowerComplex::runOnFunction() {
   // Add lowering patterns to the list.
   OwningRewritePatternList patterns;
-  mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);
+  mlir::hlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);

   applyPatternsAndFoldGreedily(getFunction(), patterns);
 }

 static PassRegistration<LowerComplex> pass(
-    "test-xla-lower-complex",
+    "mhlo-test-lower-complex",
     "Lower complex operations into non-complex operations");
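The generated patterns pulled in above rewrite ops on complex tensors into ops on their real and imaginary parts; for example, (a+bi) + (c+di) becomes (a+c) + (b+d)i. A minimal sketch of reusing the renamed entry point from another pass follows; MyComplexLowering is a hypothetical pass name, and the body mirrors LowerComplex::runOnFunction above:

// Hypothetical consumer pass, for illustration only: pulls in the same
// generated complex-lowering patterns through the renamed mlir::hlo namespace.
struct MyComplexLowering
    : public PassWrapper<MyComplexLowering, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    mlir::hlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};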
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-// This file implements logic for lowering XLA general dot to a regular dot.
+// This file implements logic for lowering MHLO general dot to a regular dot.

 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"

@@ -188,5 +188,5 @@ void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
 }

 static PassRegistration<LegalizeGeneralDot> legalize_pass(
-    "test-xla-lower-general-dot",
+    "mhlo-test-lower-general-dot",
     "Tests lowering general dot to a non-batched dot when possible");
@@ -54,5 +54,5 @@ struct TestMaterializeBroadcastsPass
 }  // namespace mlir

 static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
-    "test-xla-materialize-broadcasts",
+    "mhlo-test-materialize-broadcasts",
     "Test pass for materializing 'broadcast_dimensions' attributes");
@@ -479,7 +479,7 @@ class FusionPlanner {
   EquivalenceClasses<int32_t> leader_for_node_;
 };

-struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
+struct MhloFusion : public mlir::PassWrapper<MhloFusion, FunctionPass> {
   void runOnFunction() override {
     FuncOp func = getFunction();
     if (!IsTargetFunc(func)) {

@@ -568,12 +568,12 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {

 }  // namespace

-std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
-  return std::make_unique<XlaHloFusion>();
+std::unique_ptr<OperationPass<FuncOp>> createMhloFusion() {
+  return std::make_unique<MhloFusion>();
 }

-static PassRegistration<XlaHloFusion> mhlo_fusion_pass(
-    "xla-hlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");
+static PassRegistration<MhloFusion> mhlo_fusion_pass(
+    "mhlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");

 }  // namespace mhlo
 }  // namespace mlir
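Downstream code can schedule the renamed pass either through the mhlo-fusion flag or programmatically via the renamed factory. A short sketch follows, assuming a PassManager set up elsewhere; AddMhloFusion is a hypothetical helper, not part of this commit:

// Illustrative only: schedule the fusion pass from C++ using the factory
// renamed in this commit (createXlaHloFusion -> createMhloFusion).
void AddMhloFusion(mlir::PassManager& pm) {
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createMhloFusion());
}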
@@ -71,7 +71,7 @@ class SinkConstantsToControlFlow
 };

 static mlir::PassRegistration<SinkConstantsToControlFlow> pass(
-    "xla-hlo-sink-constants-to-control-flow",
+    "mhlo-sink-constants-to-control-flow",
     "Sink constants implicitly captured in control flow regions. This is "
     "necessary to export to XLA.");
@@ -22,12 +22,12 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"

 namespace mlir {
-namespace xla {
+namespace hlo {
 namespace {

 struct InferReturnTypeComponentsPattern : public RewritePattern {
   InferReturnTypeComponentsPattern(MLIRContext *context)
-      : RewritePattern("xla_test.get_return_type_components", 1, context) {}
+      : RewritePattern("mhlo_test.get_return_type_components", 1, context) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     if (op->getNumOperands() != 1) return failure();

@@ -44,7 +44,7 @@ struct InferReturnTypeComponentsPattern : public RewritePattern {
     }

     // Replace the op with another pass-through op with attributes added.
-    OperationState state(op->getLoc(), "xla_test.return_type_components",
+    OperationState state(op->getLoc(), "mhlo_test.return_type_components",
                          op->getOperands(), op->getResultTypes(),
                          op->getAttrs());
     auto new_op = rewriter.createOperation(state);

@@ -65,7 +65,7 @@ struct InferReturnTypeComponentsPattern : public RewritePattern {

 struct ReifyReturnTypeShapesPattern : public RewritePattern {
   ReifyReturnTypeShapesPattern(MLIRContext *context)
-      : RewritePattern("xla_test.reify_return_type_shapes", 1, context) {}
+      : RewritePattern("mhlo_test.reify_return_type_shapes", 1, context) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     if (op->getNumOperands() != 1) return failure();

@@ -92,9 +92,9 @@ struct TestInferShapedTypeMethodsPass
 };

 }  // namespace
-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir

-static mlir::PassRegistration<mlir::xla::TestInferShapedTypeMethodsPass> pass(
-    "test-xla-infer-shaped-type-methods",
+static mlir::PassRegistration<mlir::hlo::TestInferShapedTypeMethodsPass> pass(
+    "mhlo-test-infer-shaped-type-methods",
     "Uses test ops to invoke InferShapedTypeOpInterface methods");
@@ -42,5 +42,5 @@ struct TestUnfuseBatchNormPass
 }  // namespace mlir

 static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass(
-    "test-xla-unfuse-batch-norm",
+    "mhlo-test-unfuse-batch-norm",
     "Test pass for materializing 'broadcast_dimensions' attributes");
@@ -24,7 +24,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"

 namespace mlir {
-namespace xla {
+namespace hlo {

 bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
                                  DenseIntElementsAttr broadcast_dims) {

@@ -70,5 +70,5 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
                                                         result_shape_v);
 }

-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
@@ -22,7 +22,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"

 namespace mlir {
-namespace xla {
+namespace hlo {

 mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
                                        mlir::Type new_type) {

@@ -82,5 +82,5 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
   }));
 }

-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
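Only the enclosing namespace changes here; the helper's signature is intact. A hedged usage sketch follows; ConvertToF32 is a hypothetical wrapper, and the element-type interpretation of new_type matches the call above:

// Illustrative only: re-type a dense constant to f32 elements using the
// renamed helper (mlir::xla::ConvertElementsAttr -> mlir::hlo::ConvertElementsAttr).
mlir::ElementsAttr ConvertToF32(const mlir::ElementsAttr& attr,
                                mlir::Builder& builder) {
  return mlir::hlo::ConvertElementsAttr(attr, builder.getF32Type());
}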
@@ -20,7 +20,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"

 namespace mlir {
-namespace xla {
+namespace hlo {

 DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y,
                                                 bool allow_empty) {

@@ -66,5 +66,5 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
   return DenseElementsAttr::get(scalar_ty, value);
 }

-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
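Both helpers keep their signatures under the renamed namespace. A hedged sketch of typical use follows; BuildExample and its arguments are illustrative, not part of this commit:

// Illustrative only: compute a broadcast_dimensions attribute for two operands
// and materialize a scalar zero of the element type, using the helpers under
// their renamed mlir::hlo namespace.
void BuildExample(mlir::Builder* b, mlir::Value x, mlir::Value y,
                  mlir::Type element_type) {
  mlir::DenseIntElementsAttr bcast_dims =
      mlir::hlo::getBroadcastDimensionsAttr(b, x, y, /*allow_empty=*/true);
  mlir::DenseElementsAttr zero =
      mlir::hlo::GetScalarOfType(element_type, /*raw_value=*/0);
  (void)bcast_dims;  // Consumed by op-building code in real callers.
  (void)zero;
}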
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-test-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s

 // CHECK-LABEL: @broadcast_add
 // Note that all broadcast_ops are expanded from the same template, so

@@ -12,7 +12,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
   // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
   // CHECK: return %[[EXTENTS]]
   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
+  %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
   return %1 : tensor<1xindex>
 }

@@ -20,8 +20,8 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
 // CHECK-LABEL: @complex_ranked_components
 func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
   %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
-  // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
-  %1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+  // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
+  %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
   return %1 : tensor<?x?xcomplex<f32>>
 }

@@ -29,8 +29,8 @@ func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
 // CHECK-LABEL: @compare_ranked_components
 func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
   %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
-  // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
-  %1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
+  // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
+  %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
   return %0 : tensor<?x?xi1>
 }

@@ -38,8 +38,8 @@ func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
 // CHECK-LABEL: @broadcast_add_ranked_components_r1
 func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
-  %1 = "xla_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
+  %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
   return %1 : tensor<?xf32>
 }

@@ -49,8 +49,8 @@ func @broadcast_add_ranked_components_r1x2(%arg0: tensor<?xf32>, %arg1: tensor<?
   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
   // TODO: Overly broad shapes are being returned. Tighten the calculation
   // and update/extend these tests.
-  // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
-  %1 = "xla_test.get_return_type_components"(%0) : (tensor<?x3xf32>) -> tensor<?x3xf32>
+  // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
+  %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x3xf32>) -> tensor<?x3xf32>
   return %1 : tensor<?x3xf32>
 }
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s

 // Check the non-broadcast case for each registered op, then just check a
 // representative op for detailed broadcast semantics.

@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -xla-legalize-control-flow %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-legalize-control-flow %s -o - | FileCheck %s

 // CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
 func @while(%arg0: tensor<i64>) -> tensor<i64> {

@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -xla-legalize-to-std %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-legalize-to-std %s -o - | FileCheck %s

 // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
 func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {

@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s

 func @tanh_f64(%arg0 : f64) -> f64 {
   %res = tanh %arg0 : f64

@@ -1,6 +1,6 @@
 // GenericAtomicRMWOp should contain only ops with no side effects.
 // Unfortunately, the legalization pattern for SelectAndScatterOp has to adapt
-// to XLA LHLO dialect using allocs/deallocs inside of GenericAtomicRMWOp body.
+// to LMHLO dialect using allocs/deallocs inside of GenericAtomicRMWOp body.
 // Lowering to STD dialect and store forwarding pass would be required to get
 // rid of them. This is exactly what is done in the real MLIR GPU pipeline, but
 // here we disable verification with `verify-each=0` to check the output IR.

@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s
+// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s

 // CHECK-LABEL: @add
 func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {

@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-test-lower-general-dot -split-input-file %s -o - | FileCheck %s

 // CHECK-LABEL: @testDebatch1
 func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {

@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-test-materialize-broadcasts -split-input-file %s -o - | FileCheck %s

 // CHECK-LABEL: @clampBroadcast
 // CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)

@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s
+// RUN: mlir-hlo-opt %s -mhlo-fusion -split-input-file | FileCheck %s

 // CHECK-LABEL: func @multi_outputs_same
 func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {

@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s
+// RUN: mlir-hlo-opt %s -mhlo-sink-constants-to-control-flow | FileCheck %s

 // Tests sinking constants to a while loop.

@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s
+// RUN: mlir-hlo-opt -split-input-file -mhlo-test-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s

 // CHECK-LABEL: @batchNormInference_2D_inner_features
 // CHECK-SAME: %[[X:[^:[:space:]]+]]