Cleanup build rule names in compiler/mlir/hlo to remove the redundant/obsolete xla_ prefix

PiperOrigin-RevId: 320320140
This commit is contained in:
Mehdi Amini 2020-07-09 03:32:16 +00:00 committed by Mehdi Amini
parent f4303855c4
commit 506ddd9c4a
51 changed files with 222 additions and 228 deletions

View File

@ -38,7 +38,7 @@ def HLOClient_Dialect : Dialect {
let name = "chlo"; let name = "chlo";
let cppNamespace = "chlo"; let cppNamespace = "chlo";
let summary = [{ let summary = [{
XLA Client HLO Ops Client HLO Ops
}]; }];
let description = [{ let description = [{
@ -60,7 +60,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions. // CHLO binary elementwise op definitions.
// From the client perspective, each of these support both explicit rank // From the client perspective, each of these support both explicit rank
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate // broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
// shape broadcasting. // shape broadcasting.

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// This file defines the operations used in the XLA dialect. // This file defines the operations used in the MHLO dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_

View File

@ -13,10 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// This is the operation definition file for XLA HLO ops which map to the // This is the operation definition file for MHLO ops.
// traditional definition in xla_data.proto (or are aligned with the goals
// thereof).
// See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto
#ifndef HLO_OPS #ifndef HLO_OPS
#define HLO_OPS #define HLO_OPS
@ -44,7 +41,7 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA nullary op definitions. // MHLO nullary op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def HLO_ConstOp : HLO_Op<"constant", def HLO_ConstOp : HLO_Op<"constant",
@ -113,7 +110,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA unary elementwise op definitions. // MHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
@ -264,7 +261,7 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
HLO_FpOrComplexTensor>, BASE_HLO_TanhOp; HLO_FpOrComplexTensor>, BASE_HLO_TanhOp;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions. // MHLO binary elementwise op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@ -363,7 +360,7 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA binary logical elementwise op definitions. // MHLO binary logical elementwise op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@ -381,7 +378,7 @@ def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp;
def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp; def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA communication op definitions. // MHLO communication op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'. // InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
@ -481,7 +478,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA parallelism related op definitions. // MHLO parallelism related op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>, def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
@ -492,7 +489,7 @@ def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA control flow op definitions. // MHLO control flow op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> { def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> {
@ -640,7 +637,7 @@ def HLO_ReduceOp: HLO_Op<"reduce", [
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA tuple op definitions. // MHLO tuple op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO_GetTupleElementOp { def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO_GetTupleElementOp {
let arguments = (ins let arguments = (ins
@ -684,7 +681,7 @@ def HLO_CompareOp: HLO_Op<"compare",
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA Slice definitions. // MHLO Slice definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def HLO_SliceOp: HLO_Op< def HLO_SliceOp: HLO_Op<
@ -745,7 +742,7 @@ def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice",
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA Other op definitions. // MHLO Other op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>, def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>,
@ -1320,7 +1317,7 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA RngUniform Operator. // MHLO RngUniform Operator.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp { def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
let arguments = (ins let arguments = (ins
@ -1347,7 +1344,7 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA Quantize Operator. // MHLO Quantize Operator.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>, def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>,
BASE_HLO_DequantizeOp { BASE_HLO_DequantizeOp {

View File

@ -35,7 +35,7 @@ def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
defvar BroadcastDimAttr = I64ElementsAttr; defvar BroadcastDimAttr = I64ElementsAttr;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA on tensors type definitions. // MHLO on tensors type definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Token type. // Token type.
@ -78,7 +78,7 @@ def HLO_StaticShapeTensor : StaticShapeTensorOf<[
AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>; AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA on tensors combined type definitions. // MHLO on tensors combined type definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Any integer or floating-point tensor types // Any integer or floating-point tensor types
@ -97,7 +97,7 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>; def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA nullary op definitions. // MHLO nullary op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class BASE_HLO_ConstOp { class BASE_HLO_ConstOp {
@ -117,7 +117,7 @@ class BASE_HLO_IotaOp {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA unary elementwise op definitions. // MHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions

View File

@ -25,19 +25,19 @@ def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">; def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
class ConstantSplat<string value> : NativeCodeCall< class ConstantSplat<string value> : NativeCodeCall<
"xla::getSplat(&$_builder, $0, " # value # ")">; "hlo::getSplat(&$_builder, $0, " # value # ")">;
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
def BinBroadcastDimensions : NativeCodeCall< def BinBroadcastDimensions : NativeCodeCall<
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">; "hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
def BinBroadcastDimensionsNonEmpty : NativeCodeCall< def BinBroadcastDimensionsNonEmpty : NativeCodeCall<
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">; "hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">;
// Here, the element type can be any integer or float type. But, note that only // Here, the element type can be any integer or float type. But, note that only
// 32 bit integers are supported for the value. // 32 bit integers are supported for the value.
class GetScalarOfType<int value> : NativeCodeCall< class GetScalarOfType<int value> : NativeCodeCall<
"xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">; "hlo::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
#endif // HLO_UTILS #endif // HLO_UTILS

View File

@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// This is the operation definition file for LXLA. // This is the operation definition file for LMHLO, the "late" MHLO variant of
// the dialect, which operates on buffers instead of tensors.
// //
// This file largely overlaps with hlo_ops.td at a logic level. It's tempting to // This file largely overlaps with mhlo_ops.td at a logic level. It's tempting to
// merge these two files together, but we need to consider the following // merge these two files together, but we need to consider the following
// obstacles: // obstacles:
// * We need to have a common representation for arguments. That is to say, // * We need to have a common representation for arguments. That is to say,
@ -43,7 +44,7 @@ def LHLO_Dialect : Dialect {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA type definitions. // LMHLO type definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Any integer tensor types // Any integer tensor types
@ -66,7 +67,7 @@ def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA nullary op definitions. // LMHLO nullary op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class LHLO_Op<string mnemonic, list<OpTrait> traits> : class LHLO_Op<string mnemonic, list<OpTrait> traits> :
@ -86,7 +87,7 @@ def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA unary elementwise op definitions. // LMHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
@ -157,7 +158,7 @@ def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HL
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp; def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions. // LMHLO binary elementwise op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@ -212,7 +213,7 @@ def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp;
def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp; def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA control flow op definitions. // LMHLO control flow op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TODO(b/139813999): specify required function signature in a type-safe way. // TODO(b/139813999): specify required function signature in a type-safe way.
@ -284,7 +285,7 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA tuple op definitions. // LMHLO tuple op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
@ -298,7 +299,7 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA Slice definitions. // LMHLO Slice definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def LHLO_SliceOp: LHLO_Op< def LHLO_SliceOp: LHLO_Op<
@ -483,7 +484,7 @@ def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA Other op definitions. // LMHLO Other op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>, def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
@ -150,15 +150,14 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
} }
template <typename PredicateType> template <typename PredicateType>
inline Optional<PredicateType> getCmpPredicate( inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
StringRef xla_comparison_direction) {
return llvm::None; return llvm::None;
} }
template <> template <>
inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>( inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
StringRef xla_comparison_direction) { StringRef comparison_direction) {
return llvm::StringSwitch<Optional<CmpFPredicate>>(xla_comparison_direction) return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
.Case("EQ", CmpFPredicate::OEQ) .Case("EQ", CmpFPredicate::OEQ)
.Case("NE", CmpFPredicate::ONE) .Case("NE", CmpFPredicate::ONE)
.Case("GE", CmpFPredicate::OGE) .Case("GE", CmpFPredicate::OGE)
@ -170,8 +169,8 @@ inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
template <> template <>
inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>( inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
StringRef xla_comparison_direction) { StringRef comparison_direction) {
return llvm::StringSwitch<Optional<CmpIPredicate>>(xla_comparison_direction) return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
.Case("EQ", CmpIPredicate::eq) .Case("EQ", CmpIPredicate::eq)
.Case("NE", CmpIPredicate::ne) .Case("NE", CmpIPredicate::ne)
.Case("GE", CmpIPredicate::sge) .Case("GE", CmpIPredicate::sge)
@ -181,11 +180,11 @@ inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
.Default(llvm::None); .Default(llvm::None);
} }
template <typename XLACompareOpTy> template <typename CompareOpTy>
inline Value MapXlaCompareOpToStdScalarOp(Location loc, inline Value MapCompareOpToStdScalarOp(Location loc,
StringRef comparison_direction, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
const auto& lhs = args[0]; const auto& lhs = args[0];
const auto& rhs = args[1]; const auto& rhs = args[1];
Type element_type = lhs.getType(); Type element_type = lhs.getType();
@ -193,15 +192,15 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc,
Optional<CmpIPredicate> predicate = Optional<CmpIPredicate> predicate =
getCmpPredicate<CmpIPredicate>(comparison_direction); getCmpPredicate<CmpIPredicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction"); assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarIOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs, return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
rhs); rhs);
} }
if (element_type.isa<FloatType>()) { if (element_type.isa<FloatType>()) {
Optional<CmpFPredicate> predicate = Optional<CmpFPredicate> predicate =
getCmpPredicate<CmpFPredicate>(comparison_direction); getCmpPredicate<CmpFPredicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction"); assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarFOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs, return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
rhs); rhs);
} }
return nullptr; return nullptr;
} }
@ -337,10 +336,10 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
loc, result_types, args, b); loc, result_types, args, b);
} }
/// Implements the conversion of XLA op to scalar op (to use within region of a /// Implements the conversion of HLO op to scalar op (to use within region of a
/// linalg.generic op) for compare-select style operations like min/max. /// linalg.generic op) for compare-select style operations like min/max.
template <typename... Args> template <typename... Args>
struct XlaCompareSelectOpToStdScalarOp { struct CompareSelectOpToStdScalarOp {
static Value map(Location loc, StringRef comparison_direction, static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
@ -352,8 +351,8 @@ struct XlaCompareSelectOpToStdScalarOp {
/// dialect with a given predicate based on the element type of the operand. /// dialect with a given predicate based on the element type of the operand.
template <typename SupportedType, typename StdCompareOp, typename Predicate, template <typename SupportedType, typename StdCompareOp, typename Predicate,
typename... Args> typename... Args>
struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate, struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
Args...> { Args...> {
static Value map(Location loc, StringRef comparison_direction, static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
@ -365,8 +364,8 @@ struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
args[0], args[1]); args[0], args[1]);
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]); return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
} }
return XlaCompareSelectOpToStdScalarOp<Args...>::map( return CompareSelectOpToStdScalarOp<Args...>::map(loc, comparison_direction,
loc, comparison_direction, result_types, args, b); result_types, args, b);
} }
}; };
@ -384,7 +383,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp< return CompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types, ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types,
args, b); args, b);
@ -395,7 +394,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp< return CompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types, ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types,
args, b); args, b);
@ -475,25 +474,25 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
} // namespace impl } // namespace impl
struct XlaOpToStdScalarOp { struct HloOpToStdScalarOp {
// Implementation for LHLO ops except lmhlo::CompareOp. // Implementation for LHLO ops except lmhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy, template <typename HloOpTy, typename LhloOpTy = HloOpTy,
typename = std::enable_if_t< typename = std::enable_if_t<
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value && !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>, std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>> std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types, static Value map(HloOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) { ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types, return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
args, b); args, b);
} }
// Implementation for HLO ops except mhlo::CompareOp. // Implementation for HLO ops except mhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>, template <typename HloOpTy, typename LhloOpTy = mhlo::HloToLhloOp<HloOpTy>,
typename = std::enable_if_t< typename = std::enable_if_t<
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value && !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, std::false_type>::value>> !std::is_same<LhloOpTy, std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types, static Value map(HloOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, int i = 0) { ArrayRef<Value> args, OpBuilder* b, int i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types, return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
args, b); args, b);
@ -505,7 +504,7 @@ struct XlaOpToStdScalarOp {
static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types, static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction(); auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>( return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b); op.getLoc(), comparison_direction, result_types, args, b);
} }
@ -516,7 +515,7 @@ struct XlaOpToStdScalarOp {
static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types, static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction(); auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>( return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b); op.getLoc(), comparison_direction, result_types, args, b);
} }
}; };
@ -524,4 +523,4 @@ struct XlaOpToStdScalarOp {
} // namespace lmhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_

View File

@ -56,7 +56,7 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass(); std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
// fuse mhlo ops to kLoop/kInput fusion patterns // fuse mhlo ops to kLoop/kInput fusion patterns
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass(); std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
} // namespace mhlo } // namespace mhlo
@ -94,12 +94,12 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
} // namespace lmhlo } // namespace lmhlo
namespace xla { namespace hlo {
/// Lowers the standard TanhOp to an approximation that does not use intrinsics. /// Lowers the standard TanhOp to an approximation that does not use intrinsics.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass(); std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
} // namespace xla } // namespace hlo
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_

View File

@ -38,8 +38,8 @@ void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
void PopulateComplexLoweringPatterns(MLIRContext *context, void PopulateComplexLoweringPatterns(MLIRContext *context,
OwningRewritePatternList *patterns); OwningRewritePatternList *patterns);
void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx); MLIRContext *ctx);
// Collection of rewrite patterns for lowering of HLO to LHLO dialect. // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
void populateHLOToLHLOConversionPattern( void populateHLOToLHLOConversionPattern(
@ -93,14 +93,14 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
} // namespace chlo } // namespace chlo
namespace xla { namespace hlo {
// Populates a pattern that translates the standard TanhOp to an approximation // Populates a pattern that translates the standard TanhOp to an approximation
// that does not use intrinsics. // that does not use intrinsics.
void PopulateTanhToApproximationPatterns(MLIRContext *context, void PopulateTanhToApproximationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns); OwningRewritePatternList *patterns);
} // namespace xla } // namespace hlo
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_ #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_
// Utilities relating to implementing HLO broadcasting. // Utilities relating to implementing HLO broadcasting.
// Note: This file should not depend on any non-MLIR TensorFlow libraries. // Note: This file should not depend on any non-MLIR TensorFlow libraries.
@ -27,7 +27,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
namespace mlir { namespace mlir {
namespace xla { namespace hlo {
// Checks whether the given operand types and broadcast_dims attr represent a // Checks whether the given operand types and broadcast_dims attr represent a
// legal combination for "numpy" style broadcasting (where 1-dims are prepended // legal combination for "numpy" style broadcasting (where 1-dims are prepended
@ -43,7 +43,7 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
Value rhs, Value rhs,
OpBuilder& builder); OpBuilder& builder);
} // namespace xla } // namespace hlo
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_

View File

@ -13,21 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
namespace mlir { namespace mlir {
namespace xla { namespace hlo {
// Converts the given elements attr to the specified elements type. // Converts the given elements attr to the specified elements type.
// Requires type of the elements and new_type to be either integer or float // Requires type of the elements and new_type to be either integer or float
// type. // type.
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type); mlir::Type new_type);
} // namespace xla } // namespace hlo
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_
#include <vector> #include <vector>
@ -162,4 +162,4 @@ class GraphCycles {
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_ #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
namespace mlir { namespace mlir {
namespace xla { namespace hlo {
// Computes the broadcast dimensions attr for an elementwise binary operator // Computes the broadcast dimensions attr for an elementwise binary operator
// between two ranked tensors. // between two ranked tensors.
@ -68,7 +68,7 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
// Requires `ty` to be either FloatType of IntegerType. // Requires `ty` to be either FloatType of IntegerType.
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value); DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
} // namespace xla } // namespace hlo
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_

View File

@ -137,7 +137,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
auto broadcast_dimensions = op->getAttr("broadcast_dimensions") auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
.dyn_cast_or_null<DenseIntElementsAttr>(); .dyn_cast_or_null<DenseIntElementsAttr>();
if (broadcast_dimensions && if (broadcast_dimensions &&
!xla::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) { !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
// Note: It is unclear whether the general specification of explicit // Note: It is unclear whether the general specification of explicit
// broadcast_dimensions on binary ops is a feature we want to carry // broadcast_dimensions on binary ops is a feature we want to carry
// forward. While it can technically be implemented for ranked-dynamic, // forward. While it can technically be implemented for ranked-dynamic,
@ -150,7 +150,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
<< "broadcast_dimensions = " << broadcast_dimensions; << "broadcast_dimensions = " << broadcast_dimensions;
} }
Value computed_shape = xla::ComputeBinaryElementwiseBroadcastingResultExtents( Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
loc, lhs, rhs, builder); loc, lhs, rhs, builder);
if (!computed_shape) return failure(); if (!computed_shape) return failure();
reifiedReturnShapes.push_back(computed_shape); reifiedReturnShapes.push_back(computed_shape);

View File

@ -17,7 +17,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
// Static initialization for XLA dialect registration. // Static initialization for *HLO dialects registration.
static mlir::DialectRegistration<mlir::mhlo::MhloDialect> mhlo_ops; static mlir::DialectRegistration<mlir::mhlo::MhloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::chlo::HloClientDialect> chlo_ops; static mlir::DialectRegistration<mlir::chlo::HloClientDialect> chlo_ops;
static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops; static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops;

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// This file defines the operations used in the XLA dialect. // This file defines the operations used in the MHLO dialect.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
@ -404,7 +404,7 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
// If the operand is constant, we can do the conversion now. // If the operand is constant, we can do the conversion now.
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) { if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
return xla::ConvertElementsAttr(elementsAttr, return hlo::ConvertElementsAttr(elementsAttr,
getElementTypeOrSelf(getResult())); getElementTypeOrSelf(getResult()));
} }
@ -2135,8 +2135,6 @@ MhloDialect::MhloDialect(MLIRContext* context)
>(); >();
addInterfaces<HLOInlinerInterface>(); addInterfaces<HLOInlinerInterface>();
addTypes<TokenType>(); addTypes<TokenType>();
// Support unknown operations because not all XLA operations are registered.
// allowUnknownOperations();
} }
Type MhloDialect::parseType(DialectAsmParser& parser) const { Type MhloDialect::parseType(DialectAsmParser& parser) const {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// This file defines the operations used in the XLA dialect. // This file defines the operations used in the LMHLO dialect.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

View File

@ -96,7 +96,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
// Check for "numpy"-style rank broadcast. // Check for "numpy"-style rank broadcast.
auto broadcast_dimensions = op.broadcast_dimensions(); auto broadcast_dimensions = op.broadcast_dimensions();
if (broadcast_dimensions && if (broadcast_dimensions &&
!xla::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) { !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
// Note: It is unclear whether the general specification of explicit // Note: It is unclear whether the general specification of explicit
// broadcast_dimensions on binary ops is a feature we want to carry // broadcast_dimensions on binary ops is a feature we want to carry
// forward. While it can technically be implemented for ranked-dynamic, // forward. While it can technically be implemented for ranked-dynamic,
@ -126,7 +126,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
Value result_extents = Value result_extents =
xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs, hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
rewriter); rewriter);
// Note that we unconditionally emit DynamicBroadcastInDim ops and let // Note that we unconditionally emit DynamicBroadcastInDim ops and let

View File

@ -53,5 +53,5 @@ struct TestChloLegalizeToHloPass
} // namespace mlir } // namespace mlir
static mlir::PassRegistration<mlir::chlo::TestChloLegalizeToHloPass> pass( static mlir::PassRegistration<mlir::chlo::TestChloLegalizeToHloPass> pass(
"test-xla-chlo-legalize-to-hlo", "mhlo-test-chlo-legalize-to-hlo",
"Test pass for applying chlo -> hlo legalization patterns"); "Test pass for applying chlo -> hlo legalization patterns");

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// This file implements logic for lowering XLA dialect to Standard dialect. // This file implements logic for lowering MHLO dialect to Standard dialect.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
@ -107,8 +107,8 @@ LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
} }
LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
// Converts an XLA while loop into control flow. This generates a set of MLIR // Converts a MHLO while loop into control flow. This generates a set of MLIR
// blocks and branches, along with inlining the regions provided by the XLA // blocks and branches, along with inlining the regions provided by the MHLO
// while loop. The structure should be similar to below: // while loop. The structure should be similar to below:
// //
// <prior operations> // <prior operations>
@ -232,5 +232,5 @@ mlir::mhlo::createLegalizeControlFlowPass() {
} }
static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass( static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass(
"xla-legalize-control-flow", "mhlo-legalize-control-flow",
"Legalize from XLA control flow to MLIR control flow"); "Legalize from MHLO control flow to CFG control flow");

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir { namespace mlir {
namespace xla { namespace hlo {
namespace { namespace {
/// Emits the fast tanh approximation that is also used by XLA. /// Emits the fast tanh approximation that is also used by XLA.
@ -149,8 +149,8 @@ void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
} }
static PassRegistration<LegalizeTanhToApproximation> legalize_pass( static PassRegistration<LegalizeTanhToApproximation> legalize_pass(
"xla-legalize-tanh-to-approximation", "mhlo-legalize-tanh-to-approximation",
"Legalize tanh from standard dialect to an approximation"); "Legalize tanh from standard dialect to an approximation");
} // namespace xla } // namespace hlo
} // namespace mlir } // namespace mlir

View File

@ -32,7 +32,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir { namespace mlir {
@ -49,12 +49,12 @@ Value getResultValue(Operation* op) {
} }
template <bool isLHLO = true> template <bool isLHLO = true>
ShapedType getXLAOpResultType(Operation* op) { ShapedType getHloOpResultType(Operation* op) {
return getResultValue<isLHLO>(op).getType().template cast<ShapedType>(); return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
} }
template <bool isLHLO = true> template <bool isLHLO = true>
bool verifyXLAOpBufferOrTensorSemantics(Operation* op) { bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
auto verifyType = [&](Value val) -> bool { auto verifyType = [&](Value val) -> bool {
return (isLHLO && val.getType().isa<MemRefType>()) || return (isLHLO && val.getType().isa<MemRefType>()) ||
(!isLHLO && val.getType().isa<RankedTensorType>()); (!isLHLO && val.getType().isa<RankedTensorType>());
@ -133,7 +133,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in lmhlo namespace. // TODO(ravishankarm) : For now use the method in lmhlo namespace.
// That method needs to be moved out of there. // That method needs to be moved out of there.
Value opResult = lmhlo::XlaOpToStdScalarOp::map<OpTy>( Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes, op, bodyResultTypes,
llvm::to_vector<2>(args.take_front(args_count)), &rewriter); llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, opResult); nestedBuilder.create<linalg::YieldOp>(loc, opResult);
@ -163,7 +163,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs()); auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs()); auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of lmhlo namespace. // TODO(ravishankarm) : Move this method out of lmhlo namespace.
Value opResult = lmhlo::XlaOpToStdScalarOp::map<LhloOp>( Value opResult = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs}, lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter); &rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out()); rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
@ -274,7 +274,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
} }
}; };
/// Base class for lowering xla operations that have one operand and one result, /// Base class for lowering HLO operations that have one operand and one result,
/// and are semantically equivalent to a copy of the input to the output (like /// and are semantically equivalent to a copy of the input to the output (like
/// transpose, some reshape, etc.). The derived classes need to provide a method /// transpose, some reshape, etc.). The derived classes need to provide a method
/// `getIndexingMaps` that returns AffineMaps for the index maps of the input /// `getIndexingMaps` that returns AffineMaps for the index maps of the input
@ -287,8 +287,8 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
OpTy op, ArrayRef<Value> args, OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure(); if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto resultType = getXLAOpResultType<isLHLO>(op); auto resultType = getHloOpResultType<isLHLO>(op);
SmallVector<AffineMap, 2> indexing_maps = SmallVector<AffineMap, 2> indexing_maps =
Derived::getIndexingMaps(op, &rewriter); Derived::getIndexingMaps(op, &rewriter);
@ -322,7 +322,7 @@ class BroadcastConverter
ShapedType inputType = ShapedType inputType =
broadcastOp.operand().getType().template cast<ShapedType>(); broadcastOp.operand().getType().template cast<ShapedType>();
unsigned inputRank = inputType.getRank(); unsigned inputRank = inputType.getRank();
unsigned nloops = getXLAOpResultType<isLHLO>(broadcastOp).getRank(); unsigned nloops = getHloOpResultType<isLHLO>(broadcastOp).getRank();
// BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
// the input's dimensions. // the input's dimensions.
@ -356,7 +356,7 @@ class HloBroadcastInDimConverter
static SmallVector<AffineMap, 2> getIndexingMaps( static SmallVector<AffineMap, 2> getIndexingMaps(
mhlo::BroadcastInDimOp broadcastOp, Builder* b) { mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
auto resultType = getXLAOpResultType<false>(broadcastOp); auto resultType = getHloOpResultType<false>(broadcastOp);
auto operandType = auto operandType =
broadcastOp.operand().getType().template cast<ShapedType>(); broadcastOp.operand().getType().template cast<ShapedType>();
unsigned nloops = resultType.getRank(); unsigned nloops = resultType.getRank();
@ -555,7 +555,7 @@ class TransposeConverter
isLHLO>::DataMovementOpConverter; isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) { static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType = auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>(); getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank(); auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs; SmallVector<AffineExpr, 2> inputExprs;
inputExprs.resize(resultType.getRank()); inputExprs.resize(resultType.getRank());
@ -579,11 +579,11 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
OpTy reshapeOp, ArrayRef<Value> args, OpTy reshapeOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(reshapeOp)) if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
return failure(); return failure();
ShapedType operandType = ShapedType operandType =
reshapeOp.operand().getType().template cast<ShapedType>(); reshapeOp.operand().getType().template cast<ShapedType>();
ShapedType resultType = getXLAOpResultType<isLHLO>(reshapeOp); ShapedType resultType = getHloOpResultType<isLHLO>(reshapeOp);
if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
return failure(); return failure();
@ -708,7 +708,7 @@ class ReverseConverter
isLHLO>::DataMovementOpConverter; isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) { static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType = auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>(); getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank(); auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs; SmallVector<AffineExpr, 2> inputExprs;
inputExprs.reserve(nloops); inputExprs.reserve(nloops);

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// This file implements logic for lowering XLA dialect to Standard dialect. // This file implements logic for lowering MHLO dialect to Standard dialect.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
@ -187,8 +187,8 @@ std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
return std::make_unique<LegalizeToStandard>(); return std::make_unique<LegalizeToStandard>();
} }
void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
mlir::MLIRContext *ctx) { mlir::MLIRContext *ctx) {
mlir::populateWithGenerated(ctx, patterns); mlir::populateWithGenerated(ctx, patterns);
patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx); patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
} }
@ -196,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
/// Perform the lowering to standard dialect. /// Perform the lowering to standard dialect.
void LegalizeToStandard::runOnFunction() { void LegalizeToStandard::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
mlir::mhlo::PopulateXlaToStdPatterns(&patterns, &getContext()); mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns); applyPatternsAndFoldGreedily(getFunction(), patterns);
} }
static PassRegistration<LegalizeToStandard> legalize_pass( static PassRegistration<LegalizeToStandard> legalize_pass(
"xla-legalize-to-std", "Legalize from XLA dialect to standard dialect"); "mhlo-legalize-to-std", "Legalize from MHLO dialect to standard dialect");
} // end namespace mhlo } // end namespace mhlo
} // end namespace mlir } // end namespace mlir

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// This is the legalization pattern definition file for XLA to StandardOps. // This is the legalization pattern definition file for MHLO to StandardOps.
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td" include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td"

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {
@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices); auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
auto result = auto result =
rewriter.create<AffineLoadOp>(loc, op.output(), result_indices); rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
Value op_result = lmhlo::XlaOpToStdScalarOp::map<DotOp>( Value op_result = lmhlo::HloOpToStdScalarOp::map<DotOp>(
op, element_type, {l, r, result}, &builder); op, element_type, {l, r, result}, &builder);
map_status = success(op_result != nullptr); map_status = success(op_result != nullptr);
if (failed(map_status)) return; if (failed(map_status)) return;
@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
ValueRange induction_vars) { ValueRange induction_vars) {
auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars); auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars); auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
Value op_result = lmhlo::XlaOpToStdScalarOp::map<LhloOpTy>( Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOpTy>(
op, element_type, {l, r}, &builder); op, element_type, {l, r}, &builder);
map_status = success(op_result != nullptr); map_status = success(op_result != nullptr);
if (failed(map_status)) return; if (failed(map_status)) return;

View File

@ -35,7 +35,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {

View File

@ -192,22 +192,22 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern; using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
lmhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/, lmhlo::ReduceOp reduce_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
// TODO(b/137624192) Implement variadic reduce. // TODO(b/137624192) Implement variadic reduce.
if (xla_reduce_op.out().size() != 1) return failure(); if (reduce_op.out().size() != 1) return failure();
scf::ReduceOp reduce_op = scf::ReduceOp scf_reduce_op =
CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter); CreateReduceOpInNestedParallelLoops(reduce_op, &rewriter);
ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op, ConvertToReductionOperator(reduce_op.getLoc(), scf_reduce_op,
&xla_reduce_op.body().front(), &rewriter); &reduce_op.body().front(), &rewriter);
rewriter.replaceOp(xla_reduce_op, llvm::None); rewriter.replaceOp(reduce_op, llvm::None);
return success(); return success();
} }
private: private:
// Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp
// refers to the parallel dimensions of `xla_reduce_op` if any and the inner // refers to the parallel dimensions of `reduce_op` if any and the inner
// ParallelOp refers to the reduction dimensions. The scf.reduce op is // ParallelOp refers to the reduction dimensions. The scf.reduce op is
// returned. // returned.
// //
@ -226,16 +226,15 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
// scf.yield // scf.yield
// } // }
scf::ReduceOp CreateReduceOpInNestedParallelLoops( scf::ReduceOp CreateReduceOpInNestedParallelLoops(
lmhlo::ReduceOp xla_reduce_op, lmhlo::ReduceOp reduce_op, ConversionPatternRewriter* rewriter) const {
ConversionPatternRewriter* rewriter) const { auto loc = reduce_op.getLoc();
auto loc = xla_reduce_op.getLoc();
DenseSet<int> reducing_dims; DenseSet<int> reducing_dims;
for (const auto& rdim : xla_reduce_op.dimensions().getIntValues()) { for (const auto& rdim : reduce_op.dimensions().getIntValues()) {
reducing_dims.insert(rdim.getSExtValue()); reducing_dims.insert(rdim.getSExtValue());
} }
Value operand = *xla_reduce_op.operands().begin(); Value operand = *reduce_op.operands().begin();
Value out = *xla_reduce_op.out().begin(); Value out = *reduce_op.out().begin();
SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step; SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step; SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
auto operand_shape = operand.getType().cast<MemRefType>().getShape(); auto operand_shape = operand.getType().cast<MemRefType>().getShape();
@ -252,7 +251,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
} }
// Load initial value from memref<element_type>. // Load initial value from memref<element_type>.
SmallVector<Value, 1> init_value = { SmallVector<Value, 1> init_value = {
rewriter->create<LoadOp>(loc, *xla_reduce_op.init_values().begin())}; rewriter->create<LoadOp>(loc, *reduce_op.init_values().begin())};
// Outer ParallelOp is not needed if it is a reduction across all dims. // Outer ParallelOp is not needed if it is a reduction across all dims.
scf::ParallelOp outer; scf::ParallelOp outer;
if (!parallel_lower.empty()) { if (!parallel_lower.empty()) {
@ -293,7 +292,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
rewriter->setInsertionPointToStart(inner.getBody()); rewriter->setInsertionPointToStart(inner.getBody());
Value elem = rewriter->create<mlir::LoadOp>( Value elem = rewriter->create<mlir::LoadOp>(
loc, *xla_reduce_op.operands().begin(), indices); loc, *reduce_op.operands().begin(), indices);
return rewriter->create<scf::ReduceOp>(loc, elem); return rewriter->create<scf::ReduceOp>(loc, elem);
} }
}; };
@ -364,42 +363,42 @@ class ReduceWindowOpConverter
using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern; using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
lmhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/, lmhlo::ReduceWindowOp reduce_window_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
scf::ParallelOp output_loop, window_loop; scf::ParallelOp output_loop, window_loop;
std::tie(output_loop, window_loop) = std::tie(output_loop, window_loop) =
CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op, CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op,
&rewriter); &rewriter);
scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
xla_reduce_window_op, output_loop, window_loop, &rewriter); reduce_window_op, output_loop, window_loop, &rewriter);
ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op, ConvertToReductionOperator(reduce_window_op.getLoc(), reduce_op,
&xla_reduce_window_op.body().front(), &rewriter); &reduce_window_op.body().front(), &rewriter);
rewriter.replaceOp(xla_reduce_window_op, llvm::None); rewriter.replaceOp(reduce_window_op, llvm::None);
return success(); return success();
} }
private: private:
std::pair<scf::ParallelOp, scf::ParallelOp> std::pair<scf::ParallelOp, scf::ParallelOp>
CreateParallelLoopsToTraverseOutputAndWindow( CreateParallelLoopsToTraverseOutputAndWindow(
lmhlo::ReduceWindowOp xla_reduce_window_op, lmhlo::ReduceWindowOp reduce_window_op,
ConversionPatternRewriter* rewriter) const { ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_window_op.getLoc(); auto loc = reduce_window_op.getLoc();
Value init_value = Value init_value =
rewriter->create<LoadOp>(loc, xla_reduce_window_op.init_value()); rewriter->create<LoadOp>(loc, reduce_window_op.init_value());
Value zero = rewriter->create<ConstantIndexOp>(loc, 0); Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
Value one = rewriter->create<ConstantIndexOp>(loc, 1); Value one = rewriter->create<ConstantIndexOp>(loc, 1);
// Create an outer parallel loop that spans the output of ReduceWindowOp. // Create an outer parallel loop that spans the output of ReduceWindowOp.
Value xla_output = xla_reduce_window_op.out(); Value output = reduce_window_op.out();
auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter); auto output_loop = MakeLoopOverShape(loc, output, rewriter);
// Create a nested loop that traverses the window. // Create a nested loop that traverses the window.
SmallVector<Value, 2> window_lower, window_upper, window_step; SmallVector<Value, 2> window_lower, window_upper, window_step;
rewriter->setInsertionPointToStart(output_loop.getBody()); rewriter->setInsertionPointToStart(output_loop.getBody());
for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) { for (const auto& window_dim : reduce_window_op.window_dimensions()) {
window_step.push_back(one); window_step.push_back(one);
window_lower.push_back(zero); window_lower.push_back(zero);
window_upper.push_back( window_upper.push_back(
@ -410,38 +409,38 @@ class ReduceWindowOpConverter
Value reduction_result = *window_loop.getResults().begin(); Value reduction_result = *window_loop.getResults().begin();
auto output_ivs = output_loop.getInductionVars(); auto output_ivs = output_loop.getInductionVars();
rewriter->create<StoreOp>(loc, reduction_result, xla_output, output_ivs); rewriter->create<StoreOp>(loc, reduction_result, output, output_ivs);
return std::make_pair(output_loop, window_loop); return std::make_pair(output_loop, window_loop);
} }
scf::ReduceOp CreateReduceOpInNestedParallelLoops( scf::ReduceOp CreateReduceOpInNestedParallelLoops(
lmhlo::ReduceWindowOp xla_reduce_window_op, scf::ParallelOp output_loop, lmhlo::ReduceWindowOp reduce_window_op, scf::ParallelOp output_loop,
scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const { scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
rewriter->setInsertionPointToStart(window_loop.getBody()); rewriter->setInsertionPointToStart(window_loop.getBody());
auto loc = xla_reduce_window_op.getLoc(); auto loc = reduce_window_op.getLoc();
if (xla_reduce_window_op.base_dilations().hasValue() || if (reduce_window_op.base_dilations().hasValue() ||
xla_reduce_window_op.window_dilations().hasValue()) { reduce_window_op.window_dilations().hasValue()) {
xla_reduce_window_op.emitRemark( reduce_window_op.emitRemark(
"Lowering to parallel loops does not support `base_dilations` or " "Lowering to parallel loops does not support `base_dilations` or "
"`window_dilations` attributes yet. The attributes will be ignored."); "`window_dilations` attributes yet. The attributes will be ignored.");
} }
Value xla_operand = xla_reduce_window_op.operand(); Value operand = reduce_window_op.operand();
auto xla_operand_type = xla_operand.getType().cast<MemRefType>(); auto operand_type = operand.getType().cast<MemRefType>();
// Compute ivs in 'arg' buffer and whether these ivs are in pad area or not. // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
MappedIvs mapped_ivs = MapWindowIvsToInput( MappedIvs mapped_ivs =
xla_reduce_window_op, output_loop.getInductionVars(), MapWindowIvsToInput(reduce_window_op, output_loop.getInductionVars(),
window_loop.getInductionVars(), rewriter); window_loop.getInductionVars(), rewriter);
auto elem_or_init = rewriter->create<scf::IfOp>( auto elem_or_init = rewriter->create<scf::IfOp>(
loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds, loc, operand_type.getElementType(), mapped_ivs.in_bounds,
/*withElseRegion=*/true); /*withElseRegion=*/true);
OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
Value elem = then_builder.create<mlir::LoadOp>( Value elem = then_builder.create<mlir::LoadOp>(
loc, xla_reduce_window_op.operand(), mapped_ivs.ivs); loc, reduce_window_op.operand(), mapped_ivs.ivs);
then_builder.create<scf::YieldOp>(loc, elem); then_builder.create<scf::YieldOp>(loc, elem);
OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); OpBuilder else_builder = elem_or_init.getElseBodyBuilder();

View File

@ -45,13 +45,13 @@ class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
public: public:
explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {} explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {}
/// Performs the lowering to XLA dialect. /// Performs the lowering to MHLO dialect.
void runOnFunction() override; void runOnFunction() override;
}; };
} // end anonymous namespace } // end anonymous namespace
namespace mlir { namespace mlir {
namespace xla { namespace hlo {
namespace { namespace {
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc" #include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc"
@ -62,18 +62,18 @@ void PopulateComplexLoweringPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
populateWithGenerated(context, patterns); populateWithGenerated(context, patterns);
} }
} // end namespace xla } // end namespace hlo
} // end namespace mlir } // end namespace mlir
// Lowers the complex operations that can be represented using other operations. // Lowers the complex operations that can be represented using other operations.
void LowerComplex::runOnFunction() { void LowerComplex::runOnFunction() {
// Add lowering patterns to the list. // Add lowering patterns to the list.
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns); mlir::hlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);
applyPatternsAndFoldGreedily(getFunction(), patterns); applyPatternsAndFoldGreedily(getFunction(), patterns);
} }
static PassRegistration<LowerComplex> pass( static PassRegistration<LowerComplex> pass(
"test-xla-lower-complex", "mhlo-test-lower-complex",
"Lower complex operations into non-complex operations"); "Lower complex operations into non-complex operations");

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// This file implements logic for lowering XLA general dot to a regular dot. // This file implements logic for lowering MHLO general dot to a regular dot.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
@ -188,5 +188,5 @@ void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
} }
static PassRegistration<LegalizeGeneralDot> legalize_pass( static PassRegistration<LegalizeGeneralDot> legalize_pass(
"test-xla-lower-general-dot", "mhlo-test-lower-general-dot",
"Tests lowering general dot to a non-batched dot when possible"); "Tests lowering general dot to a non-batched dot when possible");

View File

@ -54,5 +54,5 @@ struct TestMaterializeBroadcastsPass
} // namespace mlir } // namespace mlir
static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass( static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
"test-xla-materialize-broadcasts", "mhlo-test-materialize-broadcasts",
"Test pass for materializing 'broadcast_dimensions' attributes"); "Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -479,7 +479,7 @@ class FusionPlanner {
EquivalenceClasses<int32_t> leader_for_node_; EquivalenceClasses<int32_t> leader_for_node_;
}; };
struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> { struct MhloFusion : public mlir::PassWrapper<MhloFusion, FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
FuncOp func = getFunction(); FuncOp func = getFunction();
if (!IsTargetFunc(func)) { if (!IsTargetFunc(func)) {
@ -568,12 +568,12 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() { std::unique_ptr<OperationPass<FuncOp>> createMhloFusion() {
return std::make_unique<XlaHloFusion>(); return std::make_unique<MhloFusion>();
} }
static PassRegistration<XlaHloFusion> mhlo_fusion_pass( static PassRegistration<MhloFusion> mhlo_fusion_pass(
"xla-hlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns."); "mhlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");
} // namespace mhlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -71,7 +71,7 @@ class SinkConstantsToControlFlow
}; };
static mlir::PassRegistration<SinkConstantsToControlFlow> pass( static mlir::PassRegistration<SinkConstantsToControlFlow> pass(
"xla-hlo-sink-constants-to-control-flow", "mhlo-sink-constants-to-control-flow",
"Sink constants implicitly captured in control flow regions. This is " "Sink constants implicitly captured in control flow regions. This is "
"necessary to export to XLA."); "necessary to export to XLA.");

View File

@ -22,12 +22,12 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace xla { namespace hlo {
namespace { namespace {
struct InferReturnTypeComponentsPattern : public RewritePattern { struct InferReturnTypeComponentsPattern : public RewritePattern {
InferReturnTypeComponentsPattern(MLIRContext *context) InferReturnTypeComponentsPattern(MLIRContext *context)
: RewritePattern("xla_test.get_return_type_components", 1, context) {} : RewritePattern("mhlo_test.get_return_type_components", 1, context) {}
LogicalResult matchAndRewrite(Operation *op, LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1) return failure(); if (op->getNumOperands() != 1) return failure();
@ -44,7 +44,7 @@ struct InferReturnTypeComponentsPattern : public RewritePattern {
} }
// Replace the op with another pass-through op with attributes added. // Replace the op with another pass-through op with attributes added.
OperationState state(op->getLoc(), "xla_test.return_type_components", OperationState state(op->getLoc(), "mhlo_test.return_type_components",
op->getOperands(), op->getResultTypes(), op->getOperands(), op->getResultTypes(),
op->getAttrs()); op->getAttrs());
auto new_op = rewriter.createOperation(state); auto new_op = rewriter.createOperation(state);
@ -65,7 +65,7 @@ struct InferReturnTypeComponentsPattern : public RewritePattern {
struct ReifyReturnTypeShapesPattern : public RewritePattern { struct ReifyReturnTypeShapesPattern : public RewritePattern {
ReifyReturnTypeShapesPattern(MLIRContext *context) ReifyReturnTypeShapesPattern(MLIRContext *context)
: RewritePattern("xla_test.reify_return_type_shapes", 1, context) {} : RewritePattern("mhlo_test.reify_return_type_shapes", 1, context) {}
LogicalResult matchAndRewrite(Operation *op, LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1) return failure(); if (op->getNumOperands() != 1) return failure();
@ -92,9 +92,9 @@ struct TestInferShapedTypeMethodsPass
}; };
} // namespace } // namespace
} // namespace xla } // namespace hlo
} // namespace mlir } // namespace mlir
static mlir::PassRegistration<mlir::xla::TestInferShapedTypeMethodsPass> pass( static mlir::PassRegistration<mlir::hlo::TestInferShapedTypeMethodsPass> pass(
"test-xla-infer-shaped-type-methods", "mhlo-test-infer-shaped-type-methods",
"Uses test ops to invoke InferShapedTypeOpInterface methods"); "Uses test ops to invoke InferShapedTypeOpInterface methods");

View File

@ -42,5 +42,5 @@ struct TestUnfuseBatchNormPass
} // namespace mlir } // namespace mlir
static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass( static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass(
"test-xla-unfuse-batch-norm", "mhlo-test-unfuse-batch-norm",
"Test pass for materializing 'broadcast_dimensions' attributes"); "Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
namespace mlir { namespace mlir {
namespace xla { namespace hlo {
bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
DenseIntElementsAttr broadcast_dims) { DenseIntElementsAttr broadcast_dims) {
@ -70,5 +70,5 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
result_shape_v); result_shape_v);
} }
} // namespace xla } // namespace hlo
} // namespace mlir } // namespace mlir

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
namespace mlir { namespace mlir {
namespace xla { namespace hlo {
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type) { mlir::Type new_type) {
@ -82,5 +82,5 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
})); }));
} }
} // namespace xla } // namespace hlo
} // namespace mlir } // namespace mlir

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
namespace mlir { namespace mlir {
namespace xla { namespace hlo {
DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y, DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y,
bool allow_empty) { bool allow_empty) {
@ -66,5 +66,5 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
return DenseElementsAttr::get(scalar_ty, value); return DenseElementsAttr::get(scalar_ty, value);
} }
} // namespace xla } // namespace hlo
} // namespace mlir } // namespace mlir

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s // RUN: mlir-hlo-opt -mhlo-test-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// CHECK-LABEL: @broadcast_add // CHECK-LABEL: @broadcast_add
// Note that all broadcast_ops are expanded from the same template, so // Note that all broadcast_ops are expanded from the same template, so
@ -12,7 +12,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]] // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
// CHECK: return %[[EXTENTS]] // CHECK: return %[[EXTENTS]]
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex> %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
return %1 : tensor<1xindex> return %1 : tensor<1xindex>
} }
@ -20,8 +20,8 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
// CHECK-LABEL: @complex_ranked_components // CHECK-LABEL: @complex_ranked_components
func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> { func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>} // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
return %1 : tensor<?x?xcomplex<f32>> return %1 : tensor<?x?xcomplex<f32>>
} }
@ -29,8 +29,8 @@ func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
// CHECK-LABEL: @compare_ranked_components // CHECK-LABEL: @compare_ranked_components
func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> { func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1> %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1} // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1> %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
return %0 : tensor<?x?xi1> return %0 : tensor<?x?xi1>
} }
@ -38,8 +38,8 @@ func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
// CHECK-LABEL: @broadcast_add_ranked_components_r1 // CHECK-LABEL: @broadcast_add_ranked_components_r1
func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> { func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32} // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32> %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
return %1 : tensor<?xf32> return %1 : tensor<?xf32>
} }
@ -49,8 +49,8 @@ func @broadcast_add_ranked_components_r1x2(%arg0: tensor<?xf32>, %arg1: tensor<?
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32> %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
// TODO: Overly broad shapes are being returned. Tighten the calculation // TODO: Overly broad shapes are being returned. Tighten the calculation
// and update/extend these tests. // and update/extend these tests.
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32} // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x3xf32>) -> tensor<?x3xf32> %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x3xf32>) -> tensor<?x3xf32>
return %1 : tensor<?x3xf32> return %1 : tensor<?x3xf32>
} }

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s // RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// Check the non-broadcast case for each registered op, then just check a // Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics. // representative op for detailed broadcast semantics.

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -xla-legalize-control-flow %s -o - | FileCheck %s // RUN: mlir-hlo-opt -mhlo-legalize-control-flow %s -o - | FileCheck %s
// CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> { // CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
func @while(%arg0: tensor<i64>) -> tensor<i64> { func @while(%arg0: tensor<i64>) -> tensor<i64> {

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -xla-legalize-to-std %s -o - | FileCheck %s // RUN: mlir-hlo-opt -mhlo-legalize-to-std %s -o - | FileCheck %s
// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s // RUN: mlir-hlo-opt -mhlo-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
func @tanh_f64(%arg0 : f64) -> f64 { func @tanh_f64(%arg0 : f64) -> f64 {
%res = tanh %arg0 : f64 %res = tanh %arg0 : f64

View File

@ -1,6 +1,6 @@
// GenericAtomicRMWOp should contain only ops with no side effects. // GenericAtomicRMWOp should contain only ops with no side effects.
// Unfortunately, the legalization pattern for SelectAndScatterOp has to adapt // Unfortunately, the legalization pattern for SelectAndScatterOp has to adapt
// to XLA LHLO dialect using allocs/deallocs inside of GenericAtomicRMWOp body. // to LMHLO dialect using allocs/deallocs inside of GenericAtomicRMWOp body.
// Lowering to STD dialect and store forwarding pass would be required to get // Lowering to STD dialect and store forwarding pass would be required to get
// rid of them. This is exactly what is done in the real MLIR GPU pipeline, but // rid of them. This is exactly what is done in the real MLIR GPU pipeline, but
// here we disable verification with `verify-each=0` to check the output IR. // here we disable verification with `verify-each=0` to check the output IR.

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s // RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
// CHECK-LABEL: @add // CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s // RUN: mlir-hlo-opt -mhlo-test-lower-general-dot -split-input-file %s -o - | FileCheck %s
// CHECK-LABEL: @testDebatch1 // CHECK-LABEL: @testDebatch1
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> { func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck %s // RUN: mlir-hlo-opt -mhlo-test-materialize-broadcasts -split-input-file %s -o - | FileCheck %s
// CHECK-LABEL: @clampBroadcast // CHECK-LABEL: @clampBroadcast
// CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>) // CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s // RUN: mlir-hlo-opt %s -mhlo-fusion -split-input-file | FileCheck %s
// CHECK-LABEL: func @multi_outputs_same // CHECK-LABEL: func @multi_outputs_same
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s // RUN: mlir-hlo-opt %s -mhlo-sink-constants-to-control-flow | FileCheck %s
// Tests sinking constants to a while loop. // Tests sinking constants to a while loop.

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s // RUN: mlir-hlo-opt -split-input-file -mhlo-test-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s
// CHECK-LABEL: @batchNormInference_2D_inner_features // CHECK-LABEL: @batchNormInference_2D_inner_features
// CHECK-SAME: %[[X:[^:[:space:]]+]] // CHECK-SAME: %[[X:[^:[:space:]]+]]