Cleanup build rule names in compiler/mlir/hlo to remove the redundant/obsolete xla_ prefix

PiperOrigin-RevId: 320320140
This commit is contained in:
Mehdi Amini 2020-07-09 03:32:16 +00:00 committed by Mehdi Amini
parent f4303855c4
commit 506ddd9c4a
51 changed files with 222 additions and 228 deletions

View File

@ -38,7 +38,7 @@ def HLOClient_Dialect : Dialect {
let name = "chlo";
let cppNamespace = "chlo";
let summary = [{
XLA Client HLO Ops
Client HLO Ops
}];
let description = [{
@ -60,7 +60,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
}
//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
// CHLO binary elementwise op definitions.
// From the client perspective, each of these support both explicit rank
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
// shape broadcasting.

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the XLA dialect.
// This file defines the operations used in the MHLO dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_

View File

@ -13,10 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for XLA HLO ops which map to the
// traditional definition in xla_data.proto (or are aligned with the goals
// thereof).
// See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto
// This is the operation definition file for MHLO ops.
#ifndef HLO_OPS
#define HLO_OPS
@ -44,7 +41,7 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
}
//===----------------------------------------------------------------------===//
// XLA nullary op definitions.
// MHLO nullary op definitions.
//===----------------------------------------------------------------------===//
def HLO_ConstOp : HLO_Op<"constant",
@ -113,7 +110,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
}
//===----------------------------------------------------------------------===//
// XLA unary elementwise op definitions.
// MHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
@ -264,7 +261,7 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
HLO_FpOrComplexTensor>, BASE_HLO_TanhOp;
//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
// MHLO binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@ -363,7 +360,7 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
}
//===----------------------------------------------------------------------===//
// XLA binary logical elementwise op definitions.
// MHLO binary logical elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@ -381,7 +378,7 @@ def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp;
def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
//===----------------------------------------------------------------------===//
// XLA communication op definitions.
// MHLO communication op definitions.
//===----------------------------------------------------------------------===//
// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
@ -481,7 +478,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
}
//===----------------------------------------------------------------------===//
// XLA parallelism related op definitions.
// MHLO parallelism related op definitions.
//===----------------------------------------------------------------------===//
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
@ -492,7 +489,7 @@ def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
}
//===----------------------------------------------------------------------===//
// XLA control flow op definitions.
// MHLO control flow op definitions.
//===----------------------------------------------------------------------===//
def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> {
@ -640,7 +637,7 @@ def HLO_ReduceOp: HLO_Op<"reduce", [
}
//===----------------------------------------------------------------------===//
// XLA tuple op definitions.
// MHLO tuple op definitions.
//===----------------------------------------------------------------------===//
def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO_GetTupleElementOp {
let arguments = (ins
@ -684,7 +681,7 @@ def HLO_CompareOp: HLO_Op<"compare",
}
//===----------------------------------------------------------------------===//
// XLA Slice definitions.
// MHLO Slice definitions.
//===----------------------------------------------------------------------===//
def HLO_SliceOp: HLO_Op<
@ -745,7 +742,7 @@ def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice",
//===----------------------------------------------------------------------===//
// XLA Other op definitions.
// MHLO Other op definitions.
//===----------------------------------------------------------------------===//
def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>,
@ -1320,7 +1317,7 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
}
//===----------------------------------------------------------------------===//
// XLA RngUniform Operator.
// MHLO RngUniform Operator.
//===----------------------------------------------------------------------===//
def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
let arguments = (ins
@ -1347,7 +1344,7 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
}
//===----------------------------------------------------------------------===//
// XLA Quantize Operator.
// MHLO Quantize Operator.
//===----------------------------------------------------------------------===//
def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>,
BASE_HLO_DequantizeOp {

View File

@ -35,7 +35,7 @@ def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
defvar BroadcastDimAttr = I64ElementsAttr;
//===----------------------------------------------------------------------===//
// XLA on tensors type definitions.
// MHLO on tensors type definitions.
//===----------------------------------------------------------------------===//
// Token type.
@ -78,7 +78,7 @@ def HLO_StaticShapeTensor : StaticShapeTensorOf<[
AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
//===----------------------------------------------------------------------===//
// XLA on tensors combined type definitions.
// MHLO on tensors combined type definitions.
//===----------------------------------------------------------------------===//
// Any integer or floating-point tensor types
@ -97,7 +97,7 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
//===----------------------------------------------------------------------===//
// XLA nullary op definitions.
// MHLO nullary op definitions.
//===----------------------------------------------------------------------===//
class BASE_HLO_ConstOp {
@ -117,7 +117,7 @@ class BASE_HLO_IotaOp {
}
//===----------------------------------------------------------------------===//
// XLA unary elementwise op definitions.
// MHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions

View File

@ -25,19 +25,19 @@ def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
class ConstantSplat<string value> : NativeCodeCall<
"xla::getSplat(&$_builder, $0, " # value # ")">;
"hlo::getSplat(&$_builder, $0, " # value # ")">;
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
def BinBroadcastDimensions : NativeCodeCall<
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
"hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
def BinBroadcastDimensionsNonEmpty : NativeCodeCall<
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">;
"hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">;
// Here, the element type can be any integer or float type. But, note that only
// 32 bit integers are supported for the value.
class GetScalarOfType<int value> : NativeCodeCall<
"xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
"hlo::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
#endif // HLO_UTILS

View File

@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for LXLA.
// This is the operation definition file for LMHLO, the "late" MHLO variant of
// the dialect, which operates on buffers instead of tensors.
//
// This file largely overlaps with hlo_ops.td at a logic level. It's tempting to
// This file largely overlaps with mhlo_ops.td at a logic level. It's tempting to
// merge these two files together, but we need to consider the following
// obstacles:
// * We need to have a common representation for arguments. That is to say,
@ -43,7 +44,7 @@ def LHLO_Dialect : Dialect {
}
//===----------------------------------------------------------------------===//
// XLA type definitions.
// LMHLO type definitions.
//===----------------------------------------------------------------------===//
// Any integer tensor types
@ -66,7 +67,7 @@ def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
//===----------------------------------------------------------------------===//
// XLA nullary op definitions.
// LMHLO nullary op definitions.
//===----------------------------------------------------------------------===//
class LHLO_Op<string mnemonic, list<OpTrait> traits> :
@ -86,7 +87,7 @@ def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
}
//===----------------------------------------------------------------------===//
// XLA unary elementwise op definitions.
// LMHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
@ -157,7 +158,7 @@ def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HL
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp;
//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
// LMHLO binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@ -212,7 +213,7 @@ def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp;
def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp;
//===----------------------------------------------------------------------===//
// XLA control flow op definitions.
// LMHLO control flow op definitions.
//===----------------------------------------------------------------------===//
// TODO(b/139813999): specify required function signature in a type-safe way.
@ -284,7 +285,7 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
}
//===----------------------------------------------------------------------===//
// XLA tuple op definitions.
// LMHLO tuple op definitions.
//===----------------------------------------------------------------------===//
def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
@ -298,7 +299,7 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
}
//===----------------------------------------------------------------------===//
// XLA Slice definitions.
// LMHLO Slice definitions.
//===----------------------------------------------------------------------===//
def LHLO_SliceOp: LHLO_Op<
@ -483,7 +484,7 @@ def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
//===----------------------------------------------------------------------===//
// XLA Other op definitions.
// LMHLO Other op definitions.
//===----------------------------------------------------------------------===//
def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
@ -150,15 +150,14 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
}
template <typename PredicateType>
inline Optional<PredicateType> getCmpPredicate(
StringRef xla_comparison_direction) {
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
return llvm::None;
}
template <>
inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
StringRef xla_comparison_direction) {
return llvm::StringSwitch<Optional<CmpFPredicate>>(xla_comparison_direction)
StringRef comparison_direction) {
return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
.Case("EQ", CmpFPredicate::OEQ)
.Case("NE", CmpFPredicate::ONE)
.Case("GE", CmpFPredicate::OGE)
@ -170,8 +169,8 @@ inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
template <>
inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
StringRef xla_comparison_direction) {
return llvm::StringSwitch<Optional<CmpIPredicate>>(xla_comparison_direction)
StringRef comparison_direction) {
return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
.Case("EQ", CmpIPredicate::eq)
.Case("NE", CmpIPredicate::ne)
.Case("GE", CmpIPredicate::sge)
@ -181,8 +180,8 @@ inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
.Default(llvm::None);
}
template <typename XLACompareOpTy>
inline Value MapXlaCompareOpToStdScalarOp(Location loc,
template <typename CompareOpTy>
inline Value MapCompareOpToStdScalarOp(Location loc,
StringRef comparison_direction,
ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
@ -193,14 +192,14 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc,
Optional<CmpIPredicate> predicate =
getCmpPredicate<CmpIPredicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarIOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
rhs);
}
if (element_type.isa<FloatType>()) {
Optional<CmpFPredicate> predicate =
getCmpPredicate<CmpFPredicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarFOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
rhs);
}
return nullptr;
@ -337,10 +336,10 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
loc, result_types, args, b);
}
/// Implements the conversion of XLA op to scalar op (to use within region of a
/// Implements the conversion of HLO op to scalar op (to use within region of a
/// linalg.generic op) for compare-select style operations like min/max.
template <typename... Args>
struct XlaCompareSelectOpToStdScalarOp {
struct CompareSelectOpToStdScalarOp {
static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
@ -352,7 +351,7 @@ struct XlaCompareSelectOpToStdScalarOp {
/// dialect with a given predicate based on the element type of the operand.
template <typename SupportedType, typename StdCompareOp, typename Predicate,
typename... Args>
struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
Args...> {
static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args,
@ -365,8 +364,8 @@ struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
args[0], args[1]);
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
}
return XlaCompareSelectOpToStdScalarOp<Args...>::map(
loc, comparison_direction, result_types, args, b);
return CompareSelectOpToStdScalarOp<Args...>::map(loc, comparison_direction,
result_types, args, b);
}
};
@ -384,7 +383,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
return CompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types,
args, b);
@ -395,7 +394,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
return CompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types,
args, b);
@ -475,25 +474,25 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
} // namespace impl
struct XlaOpToStdScalarOp {
struct HloOpToStdScalarOp {
// Implementation for LHLO ops except lmhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
template <typename HloOpTy, typename LhloOpTy = HloOpTy,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
static Value map(HloOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
args, b);
}
// Implementation for HLO ops except mhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
template <typename HloOpTy, typename LhloOpTy = mhlo::HloToLhloOp<HloOpTy>,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
static Value map(HloOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
args, b);
@ -505,7 +504,7 @@ struct XlaOpToStdScalarOp {
static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
@ -516,7 +515,7 @@ struct XlaOpToStdScalarOp {
static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
};
@ -524,4 +523,4 @@ struct XlaOpToStdScalarOp {
} // namespace lmhlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_

View File

@ -56,7 +56,7 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
// fuse mhlo ops to kLoop/kInput fusion patterns
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
} // namespace mhlo
@ -94,12 +94,12 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
} // namespace lmhlo
namespace xla {
namespace hlo {
/// Lowers the standard TanhOp to an approximation that does not use intrinsics.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
} // namespace xla
} // namespace hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_

View File

@ -38,7 +38,7 @@ void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
void PopulateComplexLoweringPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
@ -93,14 +93,14 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
} // namespace chlo
namespace xla {
namespace hlo {
// Populates a pattern that translates the standard TanhOp to an approximation
// that does not use intrinsics.
void PopulateTanhToApproximationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace xla
} // namespace hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_
// Utilities relating to implementing HLO broadcasting.
// Note: This file should not depend on any non-MLIR TensorFlow libraries.
@ -27,7 +27,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
namespace mlir {
namespace xla {
namespace hlo {
// Checks whether the given operand types and broadcast_dims attr represent a
// legal combination for "numpy" style broadcasting (where 1-dims are prepended
@ -43,7 +43,7 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
Value rhs,
OpBuilder& builder);
} // namespace xla
} // namespace hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_

View File

@ -13,21 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
namespace mlir {
namespace xla {
namespace hlo {
// Converts the given elements attr to the specified elements type.
// Requires type of the elements and new_type to be either integer or float
// type.
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type);
} // namespace xla
} // namespace hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_
#include <vector>
@ -162,4 +162,4 @@ class GraphCycles {
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
namespace mlir {
namespace xla {
namespace hlo {
// Computes the broadcast dimensions attr for an elementwise binary operator
// between two ranked tensors.
@ -68,7 +68,7 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
// Requires `ty` to be either FloatType of IntegerType.
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
} // namespace xla
} // namespace hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_

View File

@ -137,7 +137,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
.dyn_cast_or_null<DenseIntElementsAttr>();
if (broadcast_dimensions &&
!xla::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
!hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
// Note: It is unclear whether the general specification of explicit
// broadcast_dimensions on binary ops is a feature we want to carry
// forward. While it can technically be implemented for ranked-dynamic,
@ -150,7 +150,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
<< "broadcast_dimensions = " << broadcast_dimensions;
}
Value computed_shape = xla::ComputeBinaryElementwiseBroadcastingResultExtents(
Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
loc, lhs, rhs, builder);
if (!computed_shape) return failure();
reifiedReturnShapes.push_back(computed_shape);

View File

@ -17,7 +17,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
// Static initialization for XLA dialect registration.
// Static initialization for *HLO dialects registration.
static mlir::DialectRegistration<mlir::mhlo::MhloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::chlo::HloClientDialect> chlo_ops;
static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops;

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the XLA dialect.
// This file defines the operations used in the MHLO dialect.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
@ -404,7 +404,7 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
// If the operand is constant, we can do the conversion now.
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
return xla::ConvertElementsAttr(elementsAttr,
return hlo::ConvertElementsAttr(elementsAttr,
getElementTypeOrSelf(getResult()));
}
@ -2135,8 +2135,6 @@ MhloDialect::MhloDialect(MLIRContext* context)
>();
addInterfaces<HLOInlinerInterface>();
addTypes<TokenType>();
// Support unknown operations because not all XLA operations are registered.
// allowUnknownOperations();
}
Type MhloDialect::parseType(DialectAsmParser& parser) const {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the XLA dialect.
// This file defines the operations used in the LMHLO dialect.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

View File

@ -96,7 +96,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
// Check for "numpy"-style rank broadcast.
auto broadcast_dimensions = op.broadcast_dimensions();
if (broadcast_dimensions &&
!xla::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
!hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
// Note: It is unclear whether the general specification of explicit
// broadcast_dimensions on binary ops is a feature we want to carry
// forward. While it can technically be implemented for ranked-dynamic,
@ -126,7 +126,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
Value result_extents =
xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
rewriter);
// Note that we unconditionally emit DynamicBroadcastInDim ops and let

View File

@ -53,5 +53,5 @@ struct TestChloLegalizeToHloPass
} // namespace mlir
static mlir::PassRegistration<mlir::chlo::TestChloLegalizeToHloPass> pass(
"test-xla-chlo-legalize-to-hlo",
"mhlo-test-chlo-legalize-to-hlo",
"Test pass for applying chlo -> hlo legalization patterns");

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering XLA dialect to Standard dialect.
// This file implements logic for lowering MHLO dialect to Standard dialect.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
@ -107,8 +107,8 @@ LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
}
LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
// Converts an XLA while loop into control flow. This generates a set of MLIR
// blocks and branches, along with inlining the regions provided by the XLA
// Converts a MHLO while loop into control flow. This generates a set of MLIR
// blocks and branches, along with inlining the regions provided by the MHLO
// while loop. The structure should be similar to below:
//
// <prior operations>
@ -232,5 +232,5 @@ mlir::mhlo::createLegalizeControlFlowPass() {
}
static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass(
"xla-legalize-control-flow",
"Legalize from XLA control flow to MLIR control flow");
"mhlo-legalize-control-flow",
"Legalize from MHLO control flow to CFG control flow");

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla {
namespace hlo {
namespace {
/// Emits the fast tanh approximation that is also used by XLA.
@ -149,8 +149,8 @@ void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
}
static PassRegistration<LegalizeTanhToApproximation> legalize_pass(
"xla-legalize-tanh-to-approximation",
"mhlo-legalize-tanh-to-approximation",
"Legalize tanh from standard dialect to an approximation");
} // namespace xla
} // namespace hlo
} // namespace mlir

View File

@ -32,7 +32,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
@ -49,12 +49,12 @@ Value getResultValue(Operation* op) {
}
template <bool isLHLO = true>
ShapedType getXLAOpResultType(Operation* op) {
ShapedType getHloOpResultType(Operation* op) {
return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
}
template <bool isLHLO = true>
bool verifyXLAOpBufferOrTensorSemantics(Operation* op) {
bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
auto verifyType = [&](Value val) -> bool {
return (isLHLO && val.getType().isa<MemRefType>()) ||
(!isLHLO && val.getType().isa<RankedTensorType>());
@ -133,7 +133,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
// That method needs to be moved out of there.
Value opResult = lmhlo::XlaOpToStdScalarOp::map<OpTy>(
Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes,
llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
@ -163,7 +163,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
Value opResult = lmhlo::XlaOpToStdScalarOp::map<LhloOp>(
Value opResult = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
@ -274,7 +274,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
}
};
/// Base class for lowering xla operations that have one operand and one result,
/// Base class for lowering HLO operations that have one operand and one result,
/// and are semantically equivalent to a copy of the input to the output (like
/// transpose, some reshape, etc.). The derived classes need to provide a method
/// `getIndexingMaps` that returns AffineMaps for the index maps of the input
@ -287,8 +287,8 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
LogicalResult matchAndRewrite(
OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto resultType = getXLAOpResultType<isLHLO>(op);
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto resultType = getHloOpResultType<isLHLO>(op);
SmallVector<AffineMap, 2> indexing_maps =
Derived::getIndexingMaps(op, &rewriter);
@ -322,7 +322,7 @@ class BroadcastConverter
ShapedType inputType =
broadcastOp.operand().getType().template cast<ShapedType>();
unsigned inputRank = inputType.getRank();
unsigned nloops = getXLAOpResultType<isLHLO>(broadcastOp).getRank();
unsigned nloops = getHloOpResultType<isLHLO>(broadcastOp).getRank();
// BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
// the input's dimensions.
@ -356,7 +356,7 @@ class HloBroadcastInDimConverter
static SmallVector<AffineMap, 2> getIndexingMaps(
mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
auto resultType = getXLAOpResultType<false>(broadcastOp);
auto resultType = getHloOpResultType<false>(broadcastOp);
auto operandType =
broadcastOp.operand().getType().template cast<ShapedType>();
unsigned nloops = resultType.getRank();
@ -555,7 +555,7 @@ class TransposeConverter
isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.resize(resultType.getRank());
@ -579,11 +579,11 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
LogicalResult matchAndRewrite(
OpTy reshapeOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
return failure();
ShapedType operandType =
reshapeOp.operand().getType().template cast<ShapedType>();
ShapedType resultType = getXLAOpResultType<isLHLO>(reshapeOp);
ShapedType resultType = getHloOpResultType<isLHLO>(reshapeOp);
if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
@ -708,7 +708,7 @@ class ReverseConverter
isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.reserve(nloops);

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering XLA dialect to Standard dialect.
// This file implements logic for lowering MHLO dialect to Standard dialect.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
@ -187,7 +187,7 @@ std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
return std::make_unique<LegalizeToStandard>();
}
void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
mlir::MLIRContext *ctx) {
mlir::populateWithGenerated(ctx, patterns);
patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
@ -196,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
/// Perform the lowering to standard dialect.
void LegalizeToStandard::runOnFunction() {
OwningRewritePatternList patterns;
mlir::mhlo::PopulateXlaToStdPatterns(&patterns, &getContext());
mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
static PassRegistration<LegalizeToStandard> legalize_pass(
"xla-legalize-to-std", "Legalize from XLA dialect to standard dialect");
"mhlo-legalize-to-std", "Legalize from MHLO dialect to standard dialect");
} // end namespace mhlo
} // end namespace mlir

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the legalization pattern definition file for XLA to StandardOps.
// This is the legalization pattern definition file for MHLO to StandardOps.
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td"

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
namespace mlir {
namespace lmhlo {
@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
auto result =
rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
Value op_result = lmhlo::XlaOpToStdScalarOp::map<DotOp>(
Value op_result = lmhlo::HloOpToStdScalarOp::map<DotOp>(
op, element_type, {l, r, result}, &builder);
map_status = success(op_result != nullptr);
if (failed(map_status)) return;
@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
ValueRange induction_vars) {
auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
Value op_result = lmhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOpTy>(
op, element_type, {l, r}, &builder);
map_status = success(op_result != nullptr);
if (failed(map_status)) return;

View File

@ -35,7 +35,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
namespace mlir {
namespace lmhlo {

View File

@ -192,22 +192,22 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
lmhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
lmhlo::ReduceOp reduce_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
// TODO(b/137624192) Implement variadic reduce.
if (xla_reduce_op.out().size() != 1) return failure();
if (reduce_op.out().size() != 1) return failure();
scf::ReduceOp reduce_op =
CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter);
ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op,
&xla_reduce_op.body().front(), &rewriter);
rewriter.replaceOp(xla_reduce_op, llvm::None);
scf::ReduceOp scf_reduce_op =
CreateReduceOpInNestedParallelLoops(reduce_op, &rewriter);
ConvertToReductionOperator(reduce_op.getLoc(), scf_reduce_op,
&reduce_op.body().front(), &rewriter);
rewriter.replaceOp(reduce_op, llvm::None);
return success();
}
private:
// Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp
// refers to the parallel dimensions of `xla_reduce_op` if any and the inner
// refers to the parallel dimensions of `reduce_op` if any and the inner
// ParallelOp refers to the reduction dimensions. The scf.reduce op is
// returned.
//
@ -226,16 +226,15 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
// scf.yield
// }
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
lmhlo::ReduceOp xla_reduce_op,
ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_op.getLoc();
lmhlo::ReduceOp reduce_op, ConversionPatternRewriter* rewriter) const {
auto loc = reduce_op.getLoc();
DenseSet<int> reducing_dims;
for (const auto& rdim : xla_reduce_op.dimensions().getIntValues()) {
for (const auto& rdim : reduce_op.dimensions().getIntValues()) {
reducing_dims.insert(rdim.getSExtValue());
}
Value operand = *xla_reduce_op.operands().begin();
Value out = *xla_reduce_op.out().begin();
Value operand = *reduce_op.operands().begin();
Value out = *reduce_op.out().begin();
SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
auto operand_shape = operand.getType().cast<MemRefType>().getShape();
@ -252,7 +251,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
}
// Load initial value from memref<element_type>.
SmallVector<Value, 1> init_value = {
rewriter->create<LoadOp>(loc, *xla_reduce_op.init_values().begin())};
rewriter->create<LoadOp>(loc, *reduce_op.init_values().begin())};
// Outer ParallelOp is not needed if it is a reduction across all dims.
scf::ParallelOp outer;
if (!parallel_lower.empty()) {
@ -293,7 +292,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
rewriter->setInsertionPointToStart(inner.getBody());
Value elem = rewriter->create<mlir::LoadOp>(
loc, *xla_reduce_op.operands().begin(), indices);
loc, *reduce_op.operands().begin(), indices);
return rewriter->create<scf::ReduceOp>(loc, elem);
}
};
@ -364,42 +363,42 @@ class ReduceWindowOpConverter
using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
lmhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
lmhlo::ReduceWindowOp reduce_window_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
scf::ParallelOp output_loop, window_loop;
std::tie(output_loop, window_loop) =
CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op,
CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op,
&rewriter);
scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
xla_reduce_window_op, output_loop, window_loop, &rewriter);
reduce_window_op, output_loop, window_loop, &rewriter);
ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op,
&xla_reduce_window_op.body().front(), &rewriter);
rewriter.replaceOp(xla_reduce_window_op, llvm::None);
ConvertToReductionOperator(reduce_window_op.getLoc(), reduce_op,
&reduce_window_op.body().front(), &rewriter);
rewriter.replaceOp(reduce_window_op, llvm::None);
return success();
}
private:
std::pair<scf::ParallelOp, scf::ParallelOp>
CreateParallelLoopsToTraverseOutputAndWindow(
lmhlo::ReduceWindowOp xla_reduce_window_op,
lmhlo::ReduceWindowOp reduce_window_op,
ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_window_op.getLoc();
auto loc = reduce_window_op.getLoc();
Value init_value =
rewriter->create<LoadOp>(loc, xla_reduce_window_op.init_value());
rewriter->create<LoadOp>(loc, reduce_window_op.init_value());
Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
Value one = rewriter->create<ConstantIndexOp>(loc, 1);
// Create an outer parallel loop that spans the output of ReduceWindowOp.
Value xla_output = xla_reduce_window_op.out();
auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter);
Value output = reduce_window_op.out();
auto output_loop = MakeLoopOverShape(loc, output, rewriter);
// Create a nested loop that traverses the window.
SmallVector<Value, 2> window_lower, window_upper, window_step;
rewriter->setInsertionPointToStart(output_loop.getBody());
for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) {
for (const auto& window_dim : reduce_window_op.window_dimensions()) {
window_step.push_back(one);
window_lower.push_back(zero);
window_upper.push_back(
@ -410,38 +409,38 @@ class ReduceWindowOpConverter
Value reduction_result = *window_loop.getResults().begin();
auto output_ivs = output_loop.getInductionVars();
rewriter->create<StoreOp>(loc, reduction_result, xla_output, output_ivs);
rewriter->create<StoreOp>(loc, reduction_result, output, output_ivs);
return std::make_pair(output_loop, window_loop);
}
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
lmhlo::ReduceWindowOp xla_reduce_window_op, scf::ParallelOp output_loop,
lmhlo::ReduceWindowOp reduce_window_op, scf::ParallelOp output_loop,
scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
rewriter->setInsertionPointToStart(window_loop.getBody());
auto loc = xla_reduce_window_op.getLoc();
auto loc = reduce_window_op.getLoc();
if (xla_reduce_window_op.base_dilations().hasValue() ||
xla_reduce_window_op.window_dilations().hasValue()) {
xla_reduce_window_op.emitRemark(
if (reduce_window_op.base_dilations().hasValue() ||
reduce_window_op.window_dilations().hasValue()) {
reduce_window_op.emitRemark(
"Lowering to parallel loops does not support `base_dilations` or "
"`window_dilations` attributes yet. The attributes will be ignored.");
}
Value xla_operand = xla_reduce_window_op.operand();
auto xla_operand_type = xla_operand.getType().cast<MemRefType>();
Value operand = reduce_window_op.operand();
auto operand_type = operand.getType().cast<MemRefType>();
// Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
MappedIvs mapped_ivs = MapWindowIvsToInput(
xla_reduce_window_op, output_loop.getInductionVars(),
MappedIvs mapped_ivs =
MapWindowIvsToInput(reduce_window_op, output_loop.getInductionVars(),
window_loop.getInductionVars(), rewriter);
auto elem_or_init = rewriter->create<scf::IfOp>(
loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds,
loc, operand_type.getElementType(), mapped_ivs.in_bounds,
/*withElseRegion=*/true);
OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
Value elem = then_builder.create<mlir::LoadOp>(
loc, xla_reduce_window_op.operand(), mapped_ivs.ivs);
loc, reduce_window_op.operand(), mapped_ivs.ivs);
then_builder.create<scf::YieldOp>(loc, elem);
OpBuilder else_builder = elem_or_init.getElseBodyBuilder();

View File

@ -45,13 +45,13 @@ class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
public:
explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {}
/// Performs the lowering to XLA dialect.
/// Performs the lowering to MHLO dialect.
void runOnFunction() override;
};
} // end anonymous namespace
namespace mlir {
namespace xla {
namespace hlo {
namespace {
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc"
@ -62,18 +62,18 @@ void PopulateComplexLoweringPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) {
populateWithGenerated(context, patterns);
}
} // end namespace xla
} // end namespace hlo
} // end namespace mlir
// Lowers the complex operations that can be represented using other operations.
void LowerComplex::runOnFunction() {
// Add lowering patterns to the list.
OwningRewritePatternList patterns;
mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);
mlir::hlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
static PassRegistration<LowerComplex> pass(
"test-xla-lower-complex",
"mhlo-test-lower-complex",
"Lower complex operations into non-complex operations");

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering XLA general dot to a regular dot.
// This file implements logic for lowering MHLO general dot to a regular dot.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
@ -188,5 +188,5 @@ void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
}
static PassRegistration<LegalizeGeneralDot> legalize_pass(
"test-xla-lower-general-dot",
"mhlo-test-lower-general-dot",
"Tests lowering general dot to a non-batched dot when possible");

View File

@ -54,5 +54,5 @@ struct TestMaterializeBroadcastsPass
} // namespace mlir
static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
"test-xla-materialize-broadcasts",
"mhlo-test-materialize-broadcasts",
"Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -479,7 +479,7 @@ class FusionPlanner {
EquivalenceClasses<int32_t> leader_for_node_;
};
struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
struct MhloFusion : public mlir::PassWrapper<MhloFusion, FunctionPass> {
void runOnFunction() override {
FuncOp func = getFunction();
if (!IsTargetFunc(func)) {
@ -568,12 +568,12 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
} // namespace
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
return std::make_unique<XlaHloFusion>();
std::unique_ptr<OperationPass<FuncOp>> createMhloFusion() {
return std::make_unique<MhloFusion>();
}
static PassRegistration<XlaHloFusion> mhlo_fusion_pass(
"xla-hlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");
static PassRegistration<MhloFusion> mhlo_fusion_pass(
"mhlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");
} // namespace mhlo
} // namespace mlir

View File

@ -71,7 +71,7 @@ class SinkConstantsToControlFlow
};
static mlir::PassRegistration<SinkConstantsToControlFlow> pass(
"xla-hlo-sink-constants-to-control-flow",
"mhlo-sink-constants-to-control-flow",
"Sink constants implicitly captured in control flow regions. This is "
"necessary to export to XLA.");

View File

@ -22,12 +22,12 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
namespace mlir {
namespace xla {
namespace hlo {
namespace {
struct InferReturnTypeComponentsPattern : public RewritePattern {
InferReturnTypeComponentsPattern(MLIRContext *context)
: RewritePattern("xla_test.get_return_type_components", 1, context) {}
: RewritePattern("mhlo_test.get_return_type_components", 1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1) return failure();
@ -44,7 +44,7 @@ struct InferReturnTypeComponentsPattern : public RewritePattern {
}
// Replace the op with another pass-through op with attributes added.
OperationState state(op->getLoc(), "xla_test.return_type_components",
OperationState state(op->getLoc(), "mhlo_test.return_type_components",
op->getOperands(), op->getResultTypes(),
op->getAttrs());
auto new_op = rewriter.createOperation(state);
@ -65,7 +65,7 @@ struct InferReturnTypeComponentsPattern : public RewritePattern {
struct ReifyReturnTypeShapesPattern : public RewritePattern {
ReifyReturnTypeShapesPattern(MLIRContext *context)
: RewritePattern("xla_test.reify_return_type_shapes", 1, context) {}
: RewritePattern("mhlo_test.reify_return_type_shapes", 1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1) return failure();
@ -92,9 +92,9 @@ struct TestInferShapedTypeMethodsPass
};
} // namespace
} // namespace xla
} // namespace hlo
} // namespace mlir
static mlir::PassRegistration<mlir::xla::TestInferShapedTypeMethodsPass> pass(
"test-xla-infer-shaped-type-methods",
static mlir::PassRegistration<mlir::hlo::TestInferShapedTypeMethodsPass> pass(
"mhlo-test-infer-shaped-type-methods",
"Uses test ops to invoke InferShapedTypeOpInterface methods");

View File

@ -42,5 +42,5 @@ struct TestUnfuseBatchNormPass
} // namespace mlir
static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass(
"test-xla-unfuse-batch-norm",
"mhlo-test-unfuse-batch-norm",
"Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
namespace mlir {
namespace xla {
namespace hlo {
bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
DenseIntElementsAttr broadcast_dims) {
@ -70,5 +70,5 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
result_shape_v);
}
} // namespace xla
} // namespace hlo
} // namespace mlir

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
namespace mlir {
namespace xla {
namespace hlo {
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type) {
@ -82,5 +82,5 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
}));
}
} // namespace xla
} // namespace hlo
} // namespace mlir

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
namespace mlir {
namespace xla {
namespace hlo {
DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y,
bool allow_empty) {
@ -66,5 +66,5 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
return DenseElementsAttr::get(scalar_ty, value);
}
} // namespace xla
} // namespace hlo
} // namespace mlir

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -mhlo-test-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// CHECK-LABEL: @broadcast_add
// Note that all broadcast_ops are expanded from the same template, so
@ -12,7 +12,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
// CHECK: return %[[EXTENTS]]
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
%1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
return %1 : tensor<1xindex>
}
@ -20,8 +20,8 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
// CHECK-LABEL: @complex_ranked_components
func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
// CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
%1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
return %1 : tensor<?x?xcomplex<f32>>
}
@ -29,8 +29,8 @@ func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
// CHECK-LABEL: @compare_ranked_components
func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
%1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
return %0 : tensor<?x?xi1>
}
@ -38,8 +38,8 @@ func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
// CHECK-LABEL: @broadcast_add_ranked_components_r1
func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
%1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
return %1 : tensor<?xf32>
}
@ -49,8 +49,8 @@ func @broadcast_add_ranked_components_r1x2(%arg0: tensor<?xf32>, %arg1: tensor<?
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
// TODO: Overly broad shapes are being returned. Tighten the calculation
// and update/extend these tests.
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x3xf32>) -> tensor<?x3xf32>
// CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
%1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x3xf32>) -> tensor<?x3xf32>
return %1 : tensor<?x3xf32>
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics.

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -xla-legalize-control-flow %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -mhlo-legalize-control-flow %s -o - | FileCheck %s
// CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
func @while(%arg0: tensor<i64>) -> tensor<i64> {

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -xla-legalize-to-std %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -mhlo-legalize-to-std %s -o - | FileCheck %s
// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
// RUN: mlir-hlo-opt -mhlo-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
func @tanh_f64(%arg0 : f64) -> f64 {
%res = tanh %arg0 : f64

View File

@ -1,6 +1,6 @@
// GenericAtomicRMWOp should contain only ops with no side effects.
// Unfortunately, the legalization pattern for SelectAndScatterOp has to adapt
// to XLA LHLO dialect using allocs/deallocs inside of GenericAtomicRMWOp body.
// to LMHLO dialect using allocs/deallocs inside of GenericAtomicRMWOp body.
// Lowering to STD dialect and store forwarding pass would be required to get
// rid of them. This is exactly what is done in the real MLIR GPU pipeline, but
// here we disable verification with `verify-each=0` to check the output IR.

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -mhlo-test-lower-general-dot -split-input-file %s -o - | FileCheck %s
// CHECK-LABEL: @testDebatch1
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -mhlo-test-materialize-broadcasts -split-input-file %s -o - | FileCheck %s
// CHECK-LABEL: @clampBroadcast
// CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s
// RUN: mlir-hlo-opt %s -mhlo-fusion -split-input-file | FileCheck %s
// CHECK-LABEL: func @multi_outputs_same
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s
// RUN: mlir-hlo-opt %s -mhlo-sink-constants-to-control-flow | FileCheck %s
// Tests sinking constants to a while loop.

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s
// RUN: mlir-hlo-opt -split-input-file -mhlo-test-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s
// CHECK-LABEL: @batchNormInference_2D_inner_features
// CHECK-SAME: %[[X:[^:[:space:]]+]]