diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
index 818f11a..e93bf91 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
@@ -38,7 +38,7 @@ def HLOClient_Dialect : Dialect {
   let name = "chlo";
   let cppNamespace = "chlo";
   let summary = [{
-    XLA Client HLO Ops
+    Client HLO Ops
  }];

  let description = [{
@@ -60,7 +60,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
 }

 //===----------------------------------------------------------------------===//
-// XLA binary elementwise op definitions.
+// CHLO binary elementwise op definitions.
 // From the client perspective, each of these support both explicit rank
 // broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
 // shape broadcasting.
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h
index e57aa60..976b06f 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-// This file defines the operations used in the XLA dialect.
+// This file defines the operations used in the MHLO dialect.

 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index c76165a..c714ef2 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -13,10 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-// This is the operation definition file for XLA HLO ops which map to the
-// traditional definition in xla_data.proto (or are aligned with the goals
-// thereof).
-// See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto
+// This is the operation definition file for MHLO ops.

 #ifndef HLO_OPS
 #define HLO_OPS
@@ -44,7 +41,7 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
 }

 //===----------------------------------------------------------------------===//
-// XLA nullary op definitions.
+// MHLO nullary op definitions.
 //===----------------------------------------------------------------------===//

 def HLO_ConstOp : HLO_Op<"constant",
@@ -113,7 +110,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
 }

 //===----------------------------------------------------------------------===//
-// XLA unary elementwise op definitions.
+// MHLO unary elementwise op definitions.
 //===----------------------------------------------------------------------===//

 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
@@ -264,7 +261,7 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
     HLO_FpOrComplexTensor>, BASE_HLO_TanhOp;

 //===----------------------------------------------------------------------===//
-// XLA binary elementwise op definitions.
+// MHLO binary elementwise op definitions.
 //===----------------------------------------------------------------------===//

 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@@ -363,7 +360,7 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
 }

 //===----------------------------------------------------------------------===//
-// XLA binary logical elementwise op definitions.
+// MHLO binary logical elementwise op definitions.
 //===----------------------------------------------------------------------===//

 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@@ -381,7 +378,7 @@ def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp;
 def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;

 //===----------------------------------------------------------------------===//
-// XLA communication op definitions.
+// MHLO communication op definitions.
 //===----------------------------------------------------------------------===//

 // InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
@@ -481,7 +478,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
 }

 //===----------------------------------------------------------------------===//
-// XLA parallelism related op definitions.
+// MHLO parallelism related op definitions.
 //===----------------------------------------------------------------------===//

 def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
@@ -492,7 +489,7 @@ def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
 }

 //===----------------------------------------------------------------------===//
-// XLA control flow op definitions.
+// MHLO control flow op definitions.
 //===----------------------------------------------------------------------===//

 def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> {
@@ -640,7 +637,7 @@ def HLO_ReduceOp: HLO_Op<"reduce", [
 }

 //===----------------------------------------------------------------------===//
-// XLA tuple op definitions.
+// MHLO tuple op definitions.
 //===----------------------------------------------------------------------===//
 def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO_GetTupleElementOp {
   let arguments = (ins
@@ -684,7 +681,7 @@ def HLO_CompareOp: HLO_Op<"compare",
 }

 //===----------------------------------------------------------------------===//
-// XLA Slice definitions.
+// MHLO Slice definitions.
 //===----------------------------------------------------------------------===//

 def HLO_SliceOp: HLO_Op<
@@ -745,7 +742,7 @@ def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice",


 //===----------------------------------------------------------------------===//
-// XLA Other op definitions.
+// MHLO Other op definitions.
 //===----------------------------------------------------------------------===//

 def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>,
@@ -1320,7 +1317,7 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
 }

 //===----------------------------------------------------------------------===//
-// XLA RngUniform Operator.
+// MHLO RngUniform Operator.
 //===----------------------------------------------------------------------===//
 def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
   let arguments = (ins
@@ -1347,7 +1344,7 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
 }

 //===----------------------------------------------------------------------===//
-// XLA Quantize Operator.
+// MHLO Quantize Operator.
 //===----------------------------------------------------------------------===//
 def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>,
     BASE_HLO_DequantizeOp {
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
index 98a0de3..cf90fc2 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
@@ -35,7 +35,7 @@ def HLO_Complex : Complex<AnyFloat>;
 defvar BroadcastDimAttr = I64ElementsAttr;

 //===----------------------------------------------------------------------===//
-// XLA on tensors type definitions.
+// MHLO on tensors type definitions.
 //===----------------------------------------------------------------------===//

 // Token type.
@@ -78,7 +78,7 @@ def HLO_StaticShapeTensor : StaticShapeTensorOf<[
     AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;

 //===----------------------------------------------------------------------===//
-// XLA on tensors combined type definitions.
+// MHLO on tensors combined type definitions.
 //===----------------------------------------------------------------------===//

 // Any integer or floating-point tensor types
@@ -97,7 +97,7 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
 def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;

 //===----------------------------------------------------------------------===//
-// XLA nullary op definitions.
+// MHLO nullary op definitions.
 //===----------------------------------------------------------------------===//

 class BASE_HLO_ConstOp {
@@ -117,7 +117,7 @@ class BASE_HLO_IotaOp {
 }

 //===----------------------------------------------------------------------===//
-// XLA unary elementwise op definitions.
+// MHLO unary elementwise op definitions.
 //===----------------------------------------------------------------------===//

 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td
index fdb301d..7c7b643 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td
@@ -25,19 +25,19 @@ def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
 def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;

 class ConstantSplat<string value> : NativeCodeCall<
-    "xla::getSplat(&$_builder, $0, " # value # ")">;
+    "hlo::getSplat(&$_builder, $0, " # value # ")">;

 def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;

 def BinBroadcastDimensions : NativeCodeCall<
-    "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
+    "hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;

 def BinBroadcastDimensionsNonEmpty : NativeCodeCall<
-    "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">;
+    "hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">;

 // Here, the element type can be any integer or float type. But, note that only
 // 32 bit integers are supported for the value.
 class GetScalarOfType<int value> : NativeCodeCall<
-    "xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
+    "hlo::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;

 #endif // HLO_UTILS
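For context on the helpers these patterns now call: hlo::getSplat and hlo::GetScalarOfType build constant attributes from a type plus a raw integer value. A standalone sketch of that logic follows; the helper name makeSplatOfType and the exact attribute handling are chosen for illustration, not taken from this patch.

#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"

// Illustrative sketch: build a splat DenseElementsAttr holding `raw_value`
// for a shaped type, mirroring what hlo::GetScalarOfType/getSplat provide.
static mlir::DenseElementsAttr makeSplatOfType(mlir::ShapedType ty,
                                               int64_t raw_value) {
  if (auto float_ty = ty.getElementType().dyn_cast<mlir::FloatType>()) {
    auto attr = mlir::FloatAttr::get(float_ty, static_cast<double>(raw_value));
    return mlir::DenseElementsAttr::get(ty, attr);
  }
  auto int_ty = ty.getElementType().cast<mlir::IntegerType>();
  auto attr = mlir::IntegerAttr::get(int_ty, raw_value);
  return mlir::DenseElementsAttr::get(ty, attr);
}

Note the TableGen comment above: only 32-bit integer values are supported for the `value` argument, which is why the sketch takes a plain integer.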
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index 167df89..e8fcae3 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-// This is the operation definition file for LXLA.
+// This is the operation definition file for LMHLO, the "late" MHLO variant of
+// the dialect, which operates on buffers instead of tensors.
 //
 // This file largely overlaps with hlo_ops.td at a logic level. It's tempting to
 // merge these two files together, but we need to consider the following
 // obstacles:
 // * We need to have a common representation for arguments. That is to say,
@@ -43,7 +44,7 @@ def LHLO_Dialect : Dialect {
 }

 //===----------------------------------------------------------------------===//
-// XLA type definitions.
+// LMHLO type definitions.
 //===----------------------------------------------------------------------===//

 // Any integer tensor types
@@ -66,7 +67,7 @@ def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;

 def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;

 //===----------------------------------------------------------------------===//
-// XLA nullary op definitions.
+// LMHLO nullary op definitions.
 //===----------------------------------------------------------------------===//

 class LHLO_Op<string mnemonic, list<OpTrait> traits> :
@@ -86,7 +87,7 @@ def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
 }

 //===----------------------------------------------------------------------===//
-// XLA unary elementwise op definitions.
+// LMHLO unary elementwise op definitions.
 //===----------------------------------------------------------------------===//

 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
@@ -157,7 +158,7 @@ def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HLO_SinOp;

 def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp;

 //===----------------------------------------------------------------------===//
-// XLA binary elementwise op definitions.
+// LMHLO binary elementwise op definitions.
 //===----------------------------------------------------------------------===//

 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@@ -212,7 +213,7 @@ def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp;

 def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp;

 //===----------------------------------------------------------------------===//
-// XLA control flow op definitions.
+// LMHLO control flow op definitions.
 //===----------------------------------------------------------------------===//

 // TODO(b/139813999): specify required function signature in a type-safe way.
@@ -284,7 +285,7 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
 }

 //===----------------------------------------------------------------------===//
-// XLA tuple op definitions.
+// LMHLO tuple op definitions.
 //===----------------------------------------------------------------------===//

 def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
@@ -298,7 +299,7 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
 }

 //===----------------------------------------------------------------------===//
-// XLA Slice definitions.
+// LMHLO Slice definitions.
 //===----------------------------------------------------------------------===//

 def LHLO_SliceOp: LHLO_Op<
@@ -483,7 +484,7 @@ def ReshapeMemRefCastOp: Op,
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
similarity index 91%
rename from include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h
rename to include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
index be06237..fbcb21a 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
-#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_

 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
@@ -150,15 +150,14 @@ inline Value MapLhloOpToStdScalarOp(Location loc,
 }

 template <typename T>
-inline Optional<T> getCmpPredicate(
-    StringRef xla_comparison_direction) {
+inline Optional<T> getCmpPredicate(StringRef comparison_direction) {
   return llvm::None;
 }

 template <>
 inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
-    StringRef xla_comparison_direction) {
-  return llvm::StringSwitch<Optional<CmpFPredicate>>(xla_comparison_direction)
+    StringRef comparison_direction) {
+  return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
       .Case("EQ", CmpFPredicate::OEQ)
       .Case("NE", CmpFPredicate::ONE)
       .Case("GE", CmpFPredicate::OGE)
@@ -170,8 +169,8 @@ inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(

 template <>
 inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
-    StringRef xla_comparison_direction) {
-  return llvm::StringSwitch<Optional<CmpIPredicate>>(xla_comparison_direction)
+    StringRef comparison_direction) {
+  return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
       .Case("EQ", CmpIPredicate::eq)
       .Case("NE", CmpIPredicate::ne)
       .Case("GE", CmpIPredicate::sge)
@@ -181,11 +180,11 @@ inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
       .Default(llvm::None);
 }

-template <typename XlaCompareOpTy>
-inline Value MapXlaCompareOpToStdScalarOp(Location loc,
-                                          StringRef comparison_direction,
-                                          ArrayRef<Type> result_types,
-                                          ArrayRef<Value> args, OpBuilder* b) {
+template <typename CompareOpTy>
+inline Value MapCompareOpToStdScalarOp(Location loc,
+                                       StringRef comparison_direction,
+                                       ArrayRef<Type> result_types,
+                                       ArrayRef<Value> args, OpBuilder* b) {
   const auto& lhs = args[0];
   const auto& rhs = args[1];
   Type element_type = lhs.getType();
@@ -193,15 +192,15 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc,
     Optional<CmpIPredicate> predicate =
         getCmpPredicate<CmpIPredicate>(comparison_direction);
     assert(predicate.hasValue() && "expected valid comparison direction");
-    return b->create<ScalarIOp<XlaCompareOpTy>>(loc, predicate.getValue(), lhs,
-                                                rhs);
+    return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
+                                             rhs);
   }
   if (element_type.isa<FloatType>()) {
     Optional<CmpFPredicate> predicate =
         getCmpPredicate<CmpFPredicate>(comparison_direction);
     assert(predicate.hasValue() && "expected valid comparison direction");
-    return b->create<ScalarFOp<XlaCompareOpTy>>(loc, predicate.getValue(), lhs,
-                                                rhs);
+    return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
+                                             rhs);
   }
   return nullptr;
 }
@@ -337,10 +336,10 @@ inline Value MapLhloOpToStdScalarOp(Location loc,
                                                     loc, result_types, args, b);
 }

-/// Implements the conversion of XLA op to scalar op (to use within region of a
+/// Implements the conversion of HLO op to scalar op (to use within region of a
 /// linalg.generic op) for compare-select style operations like min/max.
 template <typename... Args>
-struct XlaCompareSelectOpToStdScalarOp {
+struct CompareSelectOpToStdScalarOp {
   static Value map(Location loc, StringRef comparison_direction,
                    ArrayRef<Type> result_types, ArrayRef<Value> args,
                    OpBuilder* b) {
@@ -352,8 +351,8 @@ struct CompareSelectOpToStdScalarOp {
 /// dialect with a given predicate based on the element type of the operand.
 template <typename SupportedType, typename StdCompareOp, typename Predicate,
           typename... Args>
-struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
-                                       Args...> {
+struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
+                                    Args...> {
   static Value map(Location loc, StringRef comparison_direction,
                    ArrayRef<Type> result_types, ArrayRef<Value> args,
                    OpBuilder* b) {
@@ -365,8 +364,8 @@ struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
       return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
     }
-    return XlaCompareSelectOpToStdScalarOp<Args...>::map(
-        loc, comparison_direction, result_types, args, b);
+    return CompareSelectOpToStdScalarOp<Args...>::map(loc, comparison_direction,
+                                                      result_types, args, b);
   }
 };

@@ -384,7 +383,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc,
                                     ArrayRef<Type> result_types,
                                     ArrayRef<Value> args,
                                     OpBuilder* b) {
-  return XlaCompareSelectOpToStdScalarOp<
+  return CompareSelectOpToStdScalarOp<
       IntegerType, ScalarIOp<lmhlo::MaxOp>, CmpIPredicate, FloatType,
       ScalarFOp<lmhlo::MaxOp>, CmpFPredicate>::map(loc, "GT", result_types,
                                                    args, b);
@@ -395,7 +394,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc,
                                     ArrayRef<Type> result_types,
                                     ArrayRef<Value> args,
                                     OpBuilder* b) {
-  return XlaCompareSelectOpToStdScalarOp<
+  return CompareSelectOpToStdScalarOp<
       IntegerType, ScalarIOp<lmhlo::MinOp>, CmpIPredicate, FloatType,
       ScalarFOp<lmhlo::MinOp>, CmpFPredicate>::map(loc, "LT", result_types,
                                                    args, b);
@@ -475,25 +474,25 @@ inline Value MapLhloOpToStdScalarOp(Location loc,

 }  // namespace impl

-struct XlaOpToStdScalarOp {
+struct HloOpToStdScalarOp {
   // Implementation for LHLO ops except lmhlo::CompareOp.
-  template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
-            typename = std::enable_if_t<
-                !std::is_same<XlaOpTy, lmhlo::CompareOp>::value &&
-                std::is_same<typename mhlo::HloToLhloOp<XlaOpTy>,
-                             std::false_type>::value>>
-  static Value map(XlaOpTy op, ArrayRef<Type> result_types,
+  template <typename HloOpTy, typename LhloOpTy = HloOpTy,
+            typename = std::enable_if_t<
+                !std::is_same<HloOpTy, lmhlo::CompareOp>::value &&
+                std::is_same<typename mhlo::HloToLhloOp<HloOpTy>,
+                             std::false_type>::value>>
+  static Value map(HloOpTy op, ArrayRef<Type> result_types,
                    ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
     return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
                                                   args, b);
   }

   // Implementation for HLO ops except mhlo::CompareOp.
-  template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
+  template <typename HloOpTy, typename LhloOpTy = mhlo::HloToLhloOp<HloOpTy>,
             typename = std::enable_if_t<
                 !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
                 !std::is_same<LhloOpTy, std::false_type>::value>>
-  static Value map(XlaOpTy op, ArrayRef<Type> result_types,
+  static Value map(HloOpTy op, ArrayRef<Type> result_types,
                    ArrayRef<Value> args, OpBuilder* b, int i = 0) {
     return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
                                                   args, b);
   }
@@ -505,7 +504,7 @@ struct XlaOpToStdScalarOp {
   static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types,
                    ArrayRef<Value> args, OpBuilder* b) {
     auto comparison_direction = op.comparison_direction();
-    return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
+    return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
         op.getLoc(), comparison_direction, result_types, args, b);
   }

@@ -516,7 +515,7 @@ struct XlaOpToStdScalarOp {
   static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
                    ArrayRef<Value> args, OpBuilder* b) {
     auto comparison_direction = op.comparison_direction();
-    return impl::MapXlaCompareOpToStdScalarOp<mhlo::CompareOp>(
+    return impl::MapCompareOpToStdScalarOp<mhlo::CompareOp>(
         op.getLoc(), comparison_direction, result_types, args, b);
   }
 };
@@ -524,4 +523,4 @@ struct XlaOpToStdScalarOp {
 }  // namespace lmhlo
 }  // namespace mlir

-#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
index aa06493..8a52578 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
@@ -56,7 +56,7 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
 std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();

 // fuse mhlo ops to kLoop/kInput fusion patterns
-std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
+std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();

 }  // namespace mhlo

@@ -94,12 +94,12 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();

 }  // namespace lmhlo

-namespace xla {
+namespace hlo {

 /// Lowers the standard TanhOp to an approximation that does not use intrinsics.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();

-}  // namespace xla
+}  // namespace hlo

 }  // namespace mlir

 #endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_
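The renamed getCmpPredicate helpers are easiest to read from the call side. The sketch below restates the float mapping from the hunk above as a standalone function; it is an illustration, not code from the header:

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

// (L)MHLO compare ops carry their direction as a string; the helper maps it
// onto an ordered std.cmpf predicate (ordered: comparisons involving NaN
// evaluate to false), returning None for an unrecognized direction.
static llvm::Optional<mlir::CmpFPredicate> directionToCmpF(
    llvm::StringRef direction) {
  return llvm::StringSwitch<llvm::Optional<mlir::CmpFPredicate>>(direction)
      .Case("EQ", mlir::CmpFPredicate::OEQ)
      .Case("NE", mlir::CmpFPredicate::ONE)
      .Case("GE", mlir::CmpFPredicate::OGE)
      .Case("GT", mlir::CmpFPredicate::OGT)
      .Case("LE", mlir::CmpFPredicate::OLE)
      .Case("LT", mlir::CmpFPredicate::OLT)
      .Default(llvm::None);
}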
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index 42bc719..1e99ce0 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -38,8 +38,8 @@ void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
 void PopulateComplexLoweringPatterns(MLIRContext *context,
                                      OwningRewritePatternList *patterns);

-void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
-                              MLIRContext *ctx);
+void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
+                               MLIRContext *ctx);

 // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
 void populateHLOToLHLOConversionPattern(
@@ -93,14 +93,14 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,

 }  // namespace chlo

-namespace xla {
+namespace hlo {

 // Populates a pattern that translates the standard TanhOp to an approximation
 // that does not use intrinsics.
 void PopulateTanhToApproximationPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns);

-}  // namespace xla
+}  // namespace hlo

 }  // namespace mlir

 #endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
diff --git a/include/mlir-hlo/utils/broadcast_utils.h b/include/mlir-hlo/utils/broadcast_utils.h
index 957d5f9..85e5ffe 100644
--- a/include/mlir-hlo/utils/broadcast_utils.h
+++ b/include/mlir-hlo/utils/broadcast_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_

 // Utilities relating to implementing HLO broadcasting.
 // Note: This file should not depend on any non-MLIR TensorFlow libraries.
@@ -27,7 +27,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"

 namespace mlir {
-namespace xla {
+namespace hlo {

 // Checks whether the given operand types and broadcast_dims attr represent a
 // legal combination for "numpy" style broadcasting (where 1-dims are prepended
@@ -43,7 +43,7 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
                                                         Value rhs,
                                                         OpBuilder& builder);

-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir

-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_BROADCAST_UTILS_H_
diff --git a/include/mlir-hlo/utils/convert_op_folder.h b/include/mlir-hlo/utils/convert_op_folder.h
index dcda285..9a62a03 100644
--- a/include/mlir-hlo/utils/convert_op_folder.h
+++ b/include/mlir-hlo/utils/convert_op_folder.h
@@ -13,21 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_

 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"

 namespace mlir {
-namespace xla {
+namespace hlo {

 // Converts the given elements attr to the specified elements type.
 // Requires type of the elements and new_type to be either integer or float
 // type.
 mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
                                        mlir::Type new_type);
-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir

-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
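convert_op_folder.h keeps its signature through the rename; only the namespace changes from xla to hlo. A hedged usage sketch with made-up values, assuming the renamed header is on the include path:

#include "mlir-hlo/utils/convert_op_folder.h"
#include "mlir/IR/Builders.h"

// Illustrative: fold an integer elements attribute into an f32 one of the
// same shape, as ConvertOp::fold does with the element type of its result.
static mlir::ElementsAttr convertToF32(mlir::Builder& b,
                                       mlir::ElementsAttr elements) {
  return mlir::hlo::ConvertElementsAttr(elements, b.getF32Type());
}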
diff --git a/include/mlir-hlo/utils/cycle_detector.h b/include/mlir-hlo/utils/cycle_detector.h
index eea0f25..0cec777 100644
--- a/include/mlir-hlo/utils/cycle_detector.h
+++ b/include/mlir-hlo/utils/cycle_detector.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_

 #include <vector>

@@ -162,4 +162,4 @@ class GraphCycles {

 }  // namespace mlir

-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_
diff --git a/include/mlir-hlo/utils/hlo_utils.h b/include/mlir-hlo/utils/hlo_utils.h
index cfb7184..de9d9b3 100644
--- a/include/mlir-hlo/utils/hlo_utils.h
+++ b/include/mlir-hlo/utils/hlo_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_

 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
@@ -23,7 +23,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"

 namespace mlir {
-namespace xla {
+namespace hlo {

 // Computes the broadcast dimensions attr for an elementwise binary operator
 // between two ranked tensors.
@@ -68,7 +68,7 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
 // Requires `ty` to be either FloatType of IntegerType.
 DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);

-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir

-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc
index 1917723..63a7cda 100644
--- a/lib/Dialect/mhlo/IR/chlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/chlo_ops.cc
@@ -137,7 +137,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
   auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
                                   .dyn_cast_or_null<DenseIntElementsAttr>();
   if (broadcast_dimensions &&
-      !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
+      !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
     // Note: It is unclear whether the general specification of explicit
     // broadcast_dimensions on binary ops is a feature we want to carry
     // forward. While it can technically be implemented for ranked-dynamic,
@@ -150,7 +150,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
            << "broadcast_dimensions = " << broadcast_dimensions;
   }

-  Value computed_shape = xla::ComputeBinaryElementwiseBroadcastingResultExtents(
+  Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
       loc, lhs, rhs, builder);
   if (!computed_shape) return failure();
   reifiedReturnShapes.push_back(computed_shape);
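The legality check that chlo_ops.cc now reaches through hlo:: accepts explicit broadcast_dimensions only when they align the lower-ranked operand with the trailing dimensions of the higher-ranked one. An illustrative restatement of that rule as the surrounding comments describe it (not the library function itself):

#include "llvm/ADT/ArrayRef.h"
#include <cstdint>

// "numpy"-style rank broadcast: operand ranks may differ, but the explicit
// broadcast_dimensions must map the smaller rank onto the trailing dims,
// i.e. equal [max_rank - min_rank, ..., max_rank - 1].
static bool isTrailingAlignment(int64_t min_rank, int64_t max_rank,
                                llvm::ArrayRef<int64_t> broadcast_dims) {
  if (broadcast_dims.size() != static_cast<size_t>(min_rank)) return false;
  for (int64_t i = 0; i < min_rank; ++i) {
    if (broadcast_dims[i] != max_rank - min_rank + i) return false;
  }
  return true;
}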
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -// Static initialization for XLA dialect registration. +// Static initialization for *HLO dialects registration. static mlir::DialectRegistration mhlo_ops; static mlir::DialectRegistration chlo_ops; static mlir::DialectRegistration lmhlo_ops; diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 645bffc..5cc353b 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines the operations used in the XLA dialect. +// This file defines the operations used in the MHLO dialect. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" @@ -404,7 +404,7 @@ OpFoldResult ConvertOp::fold(ArrayRef operands) { // If the operand is constant, we can do the conversion now. if (auto elementsAttr = operands.front().dyn_cast_or_null()) { - return xla::ConvertElementsAttr(elementsAttr, + return hlo::ConvertElementsAttr(elementsAttr, getElementTypeOrSelf(getResult())); } @@ -2135,8 +2135,6 @@ MhloDialect::MhloDialect(MLIRContext* context) >(); addInterfaces(); addTypes(); - // Support unknown operations because not all XLA operations are registered. - // allowUnknownOperations(); } Type MhloDialect::parseType(DialectAsmParser& parser) const { diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index 0c60c6f..57d3b78 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines the operations used in the XLA dialect. +// This file defines the operations used in the LMHLO dialect. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 488c25f..56d8130 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -96,7 +96,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp // Check for "numpy"-style rank broadcast. auto broadcast_dimensions = op.broadcast_dimensions(); if (broadcast_dimensions && - !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) { + !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) { // Note: It is unclear whether the general specification of explicit // broadcast_dimensions on binary ops is a feature we want to carry // forward. 
diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index 488c25f..56d8130 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -96,7 +96,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
     // Check for "numpy"-style rank broadcast.
     auto broadcast_dimensions = op.broadcast_dimensions();
     if (broadcast_dimensions &&
-        !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
+        !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
       // Note: It is unclear whether the general specification of explicit
       // broadcast_dimensions on binary ops is a feature we want to carry
       // forward. While it can technically be implemented for ranked-dynamic,
@@ -126,7 +126,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
     int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
     Value result_extents =
-        xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
+        hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
                                                                rewriter);

     // Note that we unconditionally emit DynamicBroadcastInDim ops and let
diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
index e006072..c7ec85f 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
@@ -53,5 +53,5 @@ struct TestChloLegalizeToHloPass
 }  // namespace mlir

 static mlir::PassRegistration<TestChloLegalizeToHloPass> pass(
-    "test-xla-chlo-legalize-to-hlo",
+    "mhlo-test-chlo-legalize-to-hlo",
     "Test pass for applying chlo -> hlo legalization patterns");
diff --git a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
index 83cd9c4..766d3c3 100644
--- a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-// This file implements logic for lowering XLA dialect to Standard dialect.
+// This file implements logic for lowering MHLO dialect to Standard dialect.

 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
@@ -107,8 +107,8 @@ LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
 }

 LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
-  // Converts an XLA while loop into control flow. This generates a set of MLIR
-  // blocks and branches, along with inlining the regions provided by the XLA
+  // Converts an MHLO while loop into control flow. This generates a set of MLIR
+  // blocks and branches, along with inlining the regions provided by the MHLO
   // while loop. The structure should be similar to below:
   //
@@ -232,5 +232,5 @@ mlir::mhlo::createLegalizeControlFlowPass() {
 }

 static PassRegistration<LegalizeControlFlowPass> legalize_cf_pass(
-    "xla-legalize-control-flow",
-    "Legalize from XLA control flow to MLIR control flow");
+    "mhlo-legalize-control-flow",
+    "Legalize from MHLO control flow to CFG control flow");
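Only the textual registration string changes here; C++ clients keep calling the factory. A minimal scheduling sketch, assuming an ordinary PassManager setup outside this patch:

#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"

// Illustrative: run MHLO control-flow legalization on every function.
void addControlFlowLegalization(mlir::PassManager& pm) {
  pm.addNestedPass<mlir::FuncOp>(
      mlir::mhlo::createLegalizeControlFlowPass());
}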
diff --git a/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc b/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc
index fbc5f50..dfd05bb 100644
--- a/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc
@@ -24,7 +24,7 @@ limitations under the License.
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

 namespace mlir {
-namespace xla {
+namespace hlo {
 namespace {

 /// Emits the fast tanh approximation that is also used by XLA.
@@ -149,8 +149,8 @@ void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
 }

 static PassRegistration<LegalizeTanhToApproximationPass> legalize_pass(
-    "xla-legalize-tanh-to-approximation",
+    "mhlo-legalize-tanh-to-approximation",
     "Legalize tanh from standard dialect to an approximation");

-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
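For intuition about what this pass emits: std.tanh becomes a clamped rational-polynomial approximation, the same flavor XLA uses. The scalar sketch below uses the textbook Pade approximant of tanh with illustrative constants; the pass's actual coefficients and edge-case handling differ.

#include <algorithm>

// Illustrative scalar model of a tanh approximation: clamp to the saturation
// range, then evaluate an odd/even polynomial ratio (Pade [7/6] of tanh).
static float approxTanh(float x) {
  x = std::min(std::max(x, -7.90531110763549805f), 7.90531110763549805f);
  const float x2 = x * x;
  const float num = x * (135135.0f + x2 * (17325.0f + x2 * (378.0f + x2)));
  const float den =
      135135.0f + x2 * (62370.0f + x2 * (3150.0f + 28.0f * x2));
  return num / den;
}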
diff --git a/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
similarity index 97%
rename from lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc
rename to lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index ccecadd..c0d6e30 100644
--- a/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -32,7 +32,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
-#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
+#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

 namespace mlir {
@@ -49,12 +49,12 @@ Value getResultValue(Operation* op) {
 }

 template <bool isLHLO = true>
-ShapedType getXLAOpResultType(Operation* op) {
+ShapedType getHloOpResultType(Operation* op) {
   return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
 }

 template <bool isLHLO = true>
-bool verifyXLAOpBufferOrTensorSemantics(Operation* op) {
+bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
   auto verifyType = [&](Value val) -> bool {
     return (isLHLO && val.getType().isa<MemRefType>()) ||
            (!isLHLO && val.getType().isa<RankedTensorType>());
@@ -133,7 +133,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
         [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
           // TODO(ravishankarm) : For now use the method in lmhlo namespace.
           // That method needs to be moved out of there.
-          Value opResult = lmhlo::XlaOpToStdScalarOp::map<OpTy>(
+          Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
               op, bodyResultTypes,
               llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
           nestedBuilder.create<linalg::YieldOp>(loc, opResult);
@@ -163,7 +163,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
     auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
     auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
     // TODO(ravishankarm) : Move this method out of lmhlo namespace.
-    Value opResult = lmhlo::XlaOpToStdScalarOp::map<LhloOp>(
+    Value opResult = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
         lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
         &rewriter);
     rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
@@ -274,7 +274,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
   }
 };

-/// Base class for lowering xla operations that have one operand and one result,
+/// Base class for lowering HLO operations that have one operand and one result,
 /// and are semantically equivalent to a copy of the input to the output (like
 /// transpose, some reshape, etc.). The derived classes need to provide a method
 /// `getIndexingMaps` that returns AffineMaps for the index maps of the input
@@ -287,8 +287,8 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
   LogicalResult matchAndRewrite(
       OpTy op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
-    auto resultType = getXLAOpResultType<isLHLO>(op);
+    if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
+    auto resultType = getHloOpResultType<isLHLO>(op);

     SmallVector<AffineMap, 2> indexing_maps =
         Derived::getIndexingMaps(op, &rewriter);
@@ -322,7 +322,7 @@ class BroadcastConverter
     ShapedType inputType =
         broadcastOp.operand().getType().template cast<ShapedType>();
     unsigned inputRank = inputType.getRank();
-    unsigned nloops = getXLAOpResultType<isLHLO>(broadcastOp).getRank();
+    unsigned nloops = getHloOpResultType<isLHLO>(broadcastOp).getRank();

     // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
     // the input's dimensions.
@@ -356,7 +356,7 @@ class HloBroadcastInDimConverter
   static SmallVector<AffineMap, 2> getIndexingMaps(
       mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
-    auto resultType = getXLAOpResultType<false>(broadcastOp);
+    auto resultType = getHloOpResultType<false>(broadcastOp);
     auto operandType =
         broadcastOp.operand().getType().template cast<ShapedType>();
     unsigned nloops = resultType.getRank();
@@ -555,7 +555,7 @@ class TransposeConverter
                            isLHLO>::DataMovementOpConverter;
   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
     auto resultType =
-        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
+        getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
     auto nloops = resultType.getRank();
     SmallVector<AffineExpr, 2> inputExprs;
     inputExprs.resize(resultType.getRank());
@@ -579,11 +579,11 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
   LogicalResult matchAndRewrite(
       OpTy reshapeOp, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
+    if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
       return failure();
     ShapedType operandType =
         reshapeOp.operand().getType().template cast<ShapedType>();
-    ShapedType resultType = getXLAOpResultType<isLHLO>(reshapeOp);
+    ShapedType resultType = getHloOpResultType<isLHLO>(reshapeOp);

     if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
       return failure();
@@ -708,7 +708,7 @@ class ReverseConverter
                            isLHLO>::DataMovementOpConverter;
   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
     auto resultType =
-        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
+        getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
     auto nloops = resultType.getRank();
     SmallVector<AffineExpr, 2> inputExprs;
     inputExprs.reserve(nloops);
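The renamed getHloOpResultType/verifyHloOpBufferOrTensorSemantics helpers feed DataMovementOpConverter, whose generated linalg.generic carries all of its behavior in indexing maps. For reference, the two map shapes involved can be built as follows (a sketch, not code from this file):

#include "mlir/IR/AffineMap.h"

// Identity map: the op reads its input at exactly the output loop indices.
mlir::AffineMap identityMap(unsigned nloops, mlir::MLIRContext* ctx) {
  return mlir::AffineMap::getMultiDimIdentityMap(nloops, ctx);
}

// Permutation map: transpose-style data movement for a rank-2 operand.
mlir::AffineMap transpose2dMap(mlir::MLIRContext* ctx) {
  return mlir::AffineMap::getPermutationMap({1u, 0u}, ctx);
}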
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" @@ -187,8 +187,8 @@ std::unique_ptr> createLegalizeToStdPass() { return std::make_unique(); } -void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, - mlir::MLIRContext *ctx) { +void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, + mlir::MLIRContext *ctx) { mlir::populateWithGenerated(ctx, patterns); patterns->insert(ctx); } @@ -196,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, /// Perform the lowering to standard dialect. void LegalizeToStandard::runOnFunction() { OwningRewritePatternList patterns; - mlir::mhlo::PopulateXlaToStdPatterns(&patterns, &getContext()); + mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext()); applyPatternsAndFoldGreedily(getFunction(), patterns); } static PassRegistration legalize_pass( - "xla-legalize-to-std", "Legalize from XLA dialect to standard dialect"); + "mhlo-legalize-to-std", "Legalize from MHLO dialect to standard dialect"); } // end namespace mhlo } // end namespace mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td b/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td index 2d238eb..a1c3fb7 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td +++ b/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This is the legalization pattern definition file for XLA to StandardOps. +// This is the legalization pattern definition file for MHLO to StandardOps. include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td" diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index 7971240..2da3e8a 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" namespace mlir { namespace lmhlo { @@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern { auto r = builder.create(loc, rhs, rhs_indices); auto result = rewriter.create(loc, op.output(), result_indices); - Value op_result = lmhlo::XlaOpToStdScalarOp::map( + Value op_result = lmhlo::HloOpToStdScalarOp::map( op, element_type, {l, r, result}, &builder); map_status = success(op_result != nullptr); if (failed(map_status)) return; @@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern { ValueRange induction_vars) { auto l = builder.create(loc, lhs, induction_vars); auto r = builder.create(loc, rhs, induction_vars); - Value op_result = lmhlo::XlaOpToStdScalarOp::map( + Value op_result = lmhlo::HloOpToStdScalarOp::map( op, element_type, {l, r}, &builder); map_status = success(op_result != nullptr); if (failed(map_status)) return; diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index dbae1e6..4bee2cc 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -35,7 +35,7 @@ limitations under the License. #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" namespace mlir { namespace lmhlo { diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index b9f4ad4..3736763 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -192,22 +192,22 @@ class ReduceOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - lmhlo::ReduceOp xla_reduce_op, ArrayRef /*args*/, + lmhlo::ReduceOp reduce_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { // TODO(b/137624192) Implement variadic reduce. - if (xla_reduce_op.out().size() != 1) return failure(); + if (reduce_op.out().size() != 1) return failure(); - scf::ReduceOp reduce_op = - CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter); - ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op, - &xla_reduce_op.body().front(), &rewriter); - rewriter.replaceOp(xla_reduce_op, llvm::None); + scf::ReduceOp scf_reduce_op = + CreateReduceOpInNestedParallelLoops(reduce_op, &rewriter); + ConvertToReductionOperator(reduce_op.getLoc(), scf_reduce_op, + &reduce_op.body().front(), &rewriter); + rewriter.replaceOp(reduce_op, llvm::None); return success(); } private: // Creates nested `scf.parallel` ops with `scf.reduce`. 
The outer ParallelOp - // refers to the parallel dimensions of `xla_reduce_op` if any and the inner + // refers to the parallel dimensions of `reduce_op` if any and the inner // ParallelOp refers to the reduction dimensions. The scf.reduce op is // returned. // @@ -226,16 +226,15 @@ class ReduceOpConverter : public OpConversionPattern { // scf.yield // } scf::ReduceOp CreateReduceOpInNestedParallelLoops( - lmhlo::ReduceOp xla_reduce_op, - ConversionPatternRewriter* rewriter) const { - auto loc = xla_reduce_op.getLoc(); + lmhlo::ReduceOp reduce_op, ConversionPatternRewriter* rewriter) const { + auto loc = reduce_op.getLoc(); DenseSet reducing_dims; - for (const auto& rdim : xla_reduce_op.dimensions().getIntValues()) { + for (const auto& rdim : reduce_op.dimensions().getIntValues()) { reducing_dims.insert(rdim.getSExtValue()); } - Value operand = *xla_reduce_op.operands().begin(); - Value out = *xla_reduce_op.out().begin(); + Value operand = *reduce_op.operands().begin(); + Value out = *reduce_op.out().begin(); SmallVector parallel_lower, parallel_upper, parallel_step; SmallVector reduce_lower, reduce_upper, reduce_step; auto operand_shape = operand.getType().cast().getShape(); @@ -252,7 +251,7 @@ class ReduceOpConverter : public OpConversionPattern { } // Load initial value from memref. SmallVector init_value = { - rewriter->create(loc, *xla_reduce_op.init_values().begin())}; + rewriter->create(loc, *reduce_op.init_values().begin())}; // Outer ParallelOp is not needed if it is a reduction across all dims. scf::ParallelOp outer; if (!parallel_lower.empty()) { @@ -293,7 +292,7 @@ class ReduceOpConverter : public OpConversionPattern { rewriter->setInsertionPointToStart(inner.getBody()); Value elem = rewriter->create( - loc, *xla_reduce_op.operands().begin(), indices); + loc, *reduce_op.operands().begin(), indices); return rewriter->create(loc, elem); } }; @@ -364,42 +363,42 @@ class ReduceWindowOpConverter using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - lmhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, + lmhlo::ReduceWindowOp reduce_window_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { scf::ParallelOp output_loop, window_loop; std::tie(output_loop, window_loop) = - CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op, + CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op, &rewriter); scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( - xla_reduce_window_op, output_loop, window_loop, &rewriter); + reduce_window_op, output_loop, window_loop, &rewriter); - ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op, - &xla_reduce_window_op.body().front(), &rewriter); - rewriter.replaceOp(xla_reduce_window_op, llvm::None); + ConvertToReductionOperator(reduce_window_op.getLoc(), reduce_op, + &reduce_window_op.body().front(), &rewriter); + rewriter.replaceOp(reduce_window_op, llvm::None); return success(); } private: std::pair CreateParallelLoopsToTraverseOutputAndWindow( - lmhlo::ReduceWindowOp xla_reduce_window_op, + lmhlo::ReduceWindowOp reduce_window_op, ConversionPatternRewriter* rewriter) const { - auto loc = xla_reduce_window_op.getLoc(); + auto loc = reduce_window_op.getLoc(); Value init_value = - rewriter->create(loc, xla_reduce_window_op.init_value()); + rewriter->create(loc, reduce_window_op.init_value()); Value zero = rewriter->create(loc, 0); Value one = rewriter->create(loc, 1); // Create an outer parallel loop that spans the output of ReduceWindowOp. 
- Value xla_output = xla_reduce_window_op.out(); - auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter); + Value output = reduce_window_op.out(); + auto output_loop = MakeLoopOverShape(loc, output, rewriter); // Create a nested loop that traverses the window. SmallVector window_lower, window_upper, window_step; rewriter->setInsertionPointToStart(output_loop.getBody()); - for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) { + for (const auto& window_dim : reduce_window_op.window_dimensions()) { window_step.push_back(one); window_lower.push_back(zero); window_upper.push_back( @@ -410,38 +409,38 @@ class ReduceWindowOpConverter Value reduction_result = *window_loop.getResults().begin(); auto output_ivs = output_loop.getInductionVars(); - rewriter->create(loc, reduction_result, xla_output, output_ivs); + rewriter->create(loc, reduction_result, output, output_ivs); return std::make_pair(output_loop, window_loop); } scf::ReduceOp CreateReduceOpInNestedParallelLoops( - lmhlo::ReduceWindowOp xla_reduce_window_op, scf::ParallelOp output_loop, + lmhlo::ReduceWindowOp reduce_window_op, scf::ParallelOp output_loop, scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const { rewriter->setInsertionPointToStart(window_loop.getBody()); - auto loc = xla_reduce_window_op.getLoc(); + auto loc = reduce_window_op.getLoc(); - if (xla_reduce_window_op.base_dilations().hasValue() || - xla_reduce_window_op.window_dilations().hasValue()) { - xla_reduce_window_op.emitRemark( + if (reduce_window_op.base_dilations().hasValue() || + reduce_window_op.window_dilations().hasValue()) { + reduce_window_op.emitRemark( "Lowering to parallel loops does not support `base_dilations` or " "`window_dilations` attributes yet. The attributes will be ignored."); } - Value xla_operand = xla_reduce_window_op.operand(); - auto xla_operand_type = xla_operand.getType().cast(); + Value operand = reduce_window_op.operand(); + auto operand_type = operand.getType().cast(); // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not. - MappedIvs mapped_ivs = MapWindowIvsToInput( - xla_reduce_window_op, output_loop.getInductionVars(), - window_loop.getInductionVars(), rewriter); + MappedIvs mapped_ivs = + MapWindowIvsToInput(reduce_window_op, output_loop.getInductionVars(), + window_loop.getInductionVars(), rewriter); auto elem_or_init = rewriter->create( - loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds, + loc, operand_type.getElementType(), mapped_ivs.in_bounds, /*withElseRegion=*/true); OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); Value elem = then_builder.create( - loc, xla_reduce_window_op.operand(), mapped_ivs.ivs); + loc, reduce_window_op.operand(), mapped_ivs.ivs); then_builder.create(loc, elem); OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); diff --git a/lib/Dialect/mhlo/transforms/lower_complex.cc b/lib/Dialect/mhlo/transforms/lower_complex.cc index 2b85ef4..0bc47fe 100644 --- a/lib/Dialect/mhlo/transforms/lower_complex.cc +++ b/lib/Dialect/mhlo/transforms/lower_complex.cc @@ -45,13 +45,13 @@ class LowerComplex : public PassWrapper { public: explicit LowerComplex() : PassWrapper() {} - /// Performs the lowering to XLA dialect. + /// Performs the lowering to MHLO dialect. 
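A condensed view of the scaffolding this converter builds around the renamed reduce ops: constant bounds from the buffer shape, unit steps, and an scf.parallel per dimension group. This is a sketch under those assumptions (dynamic-shape bounds and the reduction body are elided), not the converter itself:

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

// Illustrative: build an scf.parallel over a static shape with unit steps.
static mlir::scf::ParallelOp makeUnitStepLoops(mlir::OpBuilder& b,
                                               mlir::Location loc,
                                               llvm::ArrayRef<int64_t> shape) {
  mlir::Value zero = b.create<mlir::ConstantIndexOp>(loc, 0);
  mlir::Value one = b.create<mlir::ConstantIndexOp>(loc, 1);
  llvm::SmallVector<mlir::Value, 4> lower, upper, step;
  for (int64_t dim : shape) {
    lower.push_back(zero);
    upper.push_back(b.create<mlir::ConstantIndexOp>(loc, dim));
    step.push_back(one);
  }
  return b.create<mlir::scf::ParallelOp>(loc, lower, upper, step);
}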
diff --git a/lib/Dialect/mhlo/transforms/lower_complex.cc b/lib/Dialect/mhlo/transforms/lower_complex.cc
index 2b85ef4..0bc47fe 100644
--- a/lib/Dialect/mhlo/transforms/lower_complex.cc
+++ b/lib/Dialect/mhlo/transforms/lower_complex.cc
@@ -45,13 +45,13 @@ class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
  public:
  explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {}

-  /// Performs the lowering to XLA dialect.
+  /// Performs the lowering to MHLO dialect.
   void runOnFunction() override;
 };
 }  // end anonymous namespace

 namespace mlir {
-namespace xla {
+namespace hlo {
 namespace {

 #include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc"
@@ -62,18 +62,18 @@ void PopulateComplexLoweringPatterns(MLIRContext* context,
                                      OwningRewritePatternList* patterns) {
   populateWithGenerated(context, patterns);
 }
-}  // end namespace xla
+}  // end namespace hlo
 }  // end namespace mlir

 // Lowers the complex operations that can be represented using other operations.
 void LowerComplex::runOnFunction() {
   // Add lowering patterns to the list.
   OwningRewritePatternList patterns;
-  mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);
+  mlir::hlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);

   applyPatternsAndFoldGreedily(getFunction(), patterns);
 }

 static PassRegistration<LowerComplex> pass(
-    "test-xla-lower-complex",
+    "mhlo-test-lower-complex",
     "Lower complex operations into non-complex operations");
diff --git a/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/lib/Dialect/mhlo/transforms/lower_general_dot.cc
index 40f3314..1b3b1ac 100644
--- a/lib/Dialect/mhlo/transforms/lower_general_dot.cc
+++ b/lib/Dialect/mhlo/transforms/lower_general_dot.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-// This file implements logic for lowering XLA general dot to a regular dot.
+// This file implements logic for lowering MHLO general dot to a regular dot.

 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
@@ -188,5 +188,5 @@ void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
 }

 static PassRegistration<LegalizeGeneralDot> legalize_pass(
-    "test-xla-lower-general-dot",
+    "mhlo-test-lower-general-dot",
     "Tests lowering general dot to a non-batched dot when possible");
diff --git a/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc
index b267642..a418e4c 100644
--- a/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc
+++ b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc
@@ -54,5 +54,5 @@ struct TestMaterializeBroadcastsPass
 }  // namespace mlir

 static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
-    "test-xla-materialize-broadcasts",
+    "mhlo-test-materialize-broadcasts",
     "Test pass for materializing 'broadcast_dimensions' attributes");
diff --git a/lib/Dialect/mhlo/transforms/mhlo_fusion.cc b/lib/Dialect/mhlo/transforms/mhlo_fusion.cc
index 568bceb..f7158a0 100644
--- a/lib/Dialect/mhlo/transforms/mhlo_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/mhlo_fusion.cc
@@ -479,7 +479,7 @@ class FusionPlanner {
   EquivalenceClasses<int32_t> leader_for_node_;
 };

-struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
+struct MhloFusion : public mlir::PassWrapper<MhloFusion, FunctionPass> {
   void runOnFunction() override {
     FuncOp func = getFunction();
     if (!IsTargetFunc(func)) {
@@ -568,12 +568,12 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {

 }  // namespace

-std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
-  return std::make_unique<XlaHloFusion>();
+std::unique_ptr<OperationPass<FuncOp>> createMhloFusion() {
+  return std::make_unique<MhloFusion>();
 }

-static PassRegistration<XlaHloFusion> mhlo_fusion_pass(
-    "xla-hlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");
+static PassRegistration<MhloFusion> mhlo_fusion_pass(
+    "mhlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");

 }  // namespace mhlo
 }  // namespace mlir
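With the registration string now mhlo-fusion, textual pipelines change accordingly, while C++ users keep going through the declaration in passes.h. A minimal sketch, assuming a standard pipeline setup:

#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"

// Illustrative: schedule the renamed fusion pass per function.
void addMhloFusion(mlir::PassManager& pm) {
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createMhloFusionPass());
}

From mlir-hlo-opt the same pass is now invoked as -mhlo-fusion.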
diff --git a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc
index dd2e663..446bfa9 100644
--- a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc
+++ b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc
@@ -71,7 +71,7 @@ class SinkConstantsToControlFlow
 };

 static mlir::PassRegistration pass(
-    "xla-hlo-sink-constants-to-control-flow",
+    "mhlo-sink-constants-to-control-flow",
     "Sink constants implicitly captured in control flow regions. This is "
     "necessary to export to XLA.");
diff --git a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc
index a7362f7..d258457 100644
--- a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc
+++ b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc
@@ -22,12 +22,12 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"

 namespace mlir {
-namespace xla {
+namespace hlo {
 namespace {

 struct InferReturnTypeComponentsPattern : public RewritePattern {
   InferReturnTypeComponentsPattern(MLIRContext *context)
-      : RewritePattern("xla_test.get_return_type_components", 1, context) {}
+      : RewritePattern("mhlo_test.get_return_type_components", 1, context) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     if (op->getNumOperands() != 1) return failure();
@@ -44,7 +44,7 @@ struct InferReturnTypeComponentsPattern : public RewritePattern {
     }

     // Replace the op with another pass-through op with attributes added.
-    OperationState state(op->getLoc(), "xla_test.return_type_components",
+    OperationState state(op->getLoc(), "mhlo_test.return_type_components",
                          op->getOperands(), op->getResultTypes(),
                          op->getAttrs());
     auto new_op = rewriter.createOperation(state);
@@ -65,7 +65,7 @@ struct InferReturnTypeComponentsPattern : public RewritePattern {

 struct ReifyReturnTypeShapesPattern : public RewritePattern {
   ReifyReturnTypeShapesPattern(MLIRContext *context)
-      : RewritePattern("xla_test.reify_return_type_shapes", 1, context) {}
+      : RewritePattern("mhlo_test.reify_return_type_shapes", 1, context) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     if (op->getNumOperands() != 1) return failure();
@@ -92,9 +92,9 @@ struct TestInferShapedTypeMethodsPass
 };

 }  // namespace
-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir

-static mlir::PassRegistration<mlir::xla::TestInferShapedTypeMethodsPass> pass(
-    "test-xla-infer-shaped-type-methods",
+static mlir::PassRegistration<mlir::hlo::TestInferShapedTypeMethodsPass> pass(
+    "mhlo-test-infer-shaped-type-methods",
     "Uses test ops to invoke InferShapedTypeOpInterface methods");
diff --git a/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
similarity index 100%
rename from lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc
rename to lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc
index 4a5b5fd..33f21d4 100644
--- a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc
+++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc
@@ -42,5 +42,5 @@ struct TestUnfuseBatchNormPass
 }  // namespace mlir

 static mlir::PassRegistration pass(
-    "test-xla-unfuse-batch-norm",
-    "Test pass for materializing 'broadcast_dimensions' attributes");
+    "mhlo-test-unfuse-batch-norm",
+    "Test pass for unfusing batch norm into constituent primitive ops");
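Note: the test patterns renamed above are anchored on raw op-name strings rather than registered op classes, which is why the rename is string-only. A condensed sketch of that idiom, using the mhlo_test.* names from the hunks (the class name is hypothetical, and operand validation is elided):

    #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"

    struct RenameOpPattern : public mlir::RewritePattern {
      explicit RenameOpPattern(mlir::MLIRContext *context)
          : RewritePattern("mhlo_test.get_return_type_components",
                           /*benefit=*/1, context) {}

      mlir::LogicalResult matchAndRewrite(
          mlir::Operation *op, mlir::PatternRewriter &rewriter) const override {
        // Re-create the op under the pass-through name, forwarding operands,
        // result types, and attributes, then swap all uses over to it.
        mlir::OperationState state(op->getLoc(),
                                   "mhlo_test.return_type_components",
                                   op->getOperands(), op->getResultTypes(),
                                   op->getAttrs());
        mlir::Operation *new_op = rewriter.createOperation(state);
        rewriter.replaceOp(op, new_op->getResults());
        return mlir::success();
      }
    };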
diff --git a/lib/utils/broadcast_utils.cc b/lib/utils/broadcast_utils.cc
index 1c1499a..1350725 100644
--- a/lib/utils/broadcast_utils.cc
+++ b/lib/utils/broadcast_utils.cc
@@ -24,7 +24,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"

 namespace mlir {
-namespace xla {
+namespace hlo {

 bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
                                  DenseIntElementsAttr broadcast_dims) {
@@ -70,5 +70,5 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
       result_shape_v);
 }

-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
diff --git a/lib/utils/convert_op_folder.cc b/lib/utils/convert_op_folder.cc
index cf6b56f..5893b37 100644
--- a/lib/utils/convert_op_folder.cc
+++ b/lib/utils/convert_op_folder.cc
@@ -22,7 +22,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"

 namespace mlir {
-namespace xla {
+namespace hlo {

 mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
                                        mlir::Type new_type) {
@@ -82,5 +82,5 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
         }));
 }

-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
diff --git a/lib/utils/hlo_utils.cc b/lib/utils/hlo_utils.cc
index 75c01a9..4528474 100644
--- a/lib/utils/hlo_utils.cc
+++ b/lib/utils/hlo_utils.cc
@@ -20,7 +20,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"

 namespace mlir {
-namespace xla {
+namespace hlo {

 DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y,
                                                 bool allow_empty) {
@@ -66,5 +66,5 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
   return DenseElementsAttr::get(scalar_ty, value);
 }

-}  // namespace xla
+}  // namespace hlo
 }  // namespace mlir
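Note: callers of these utilities see only the namespace flip from mlir::xla to mlir::hlo. A hedged usage sketch — the header path and builder setup are assumptions; the GetScalarOfType signature comes from the hlo_utils.cc hunk above:

    // Header path assumed for illustration.
    #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
    #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
    #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"

    void Demo() {
      mlir::MLIRContext context;
      mlir::Builder b(&context);
      // Formerly mlir::xla::GetScalarOfType; only the namespace changed.
      mlir::DenseElementsAttr one =
          mlir::hlo::GetScalarOfType(b.getF32Type(), /*raw_value=*/1);
    }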
diff --git a/tests/chlo_infer_shape_type_methods.mlir b/tests/chlo_infer_shape_type_methods.mlir
index ab50780..6507432 100644
--- a/tests/chlo_infer_shape_type_methods.mlir
+++ b/tests/chlo_infer_shape_type_methods.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-test-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s

 // CHECK-LABEL: @broadcast_add
 // Note that all broadcast_ops are expanded from the same template, so
@@ -12,7 +12,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xindex> {
   // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
   // CHECK: return %[[EXTENTS]]
   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
+  %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
   return %1 : tensor<1xindex>
 }
@@ -20,8 +20,8 @@
 // CHECK-LABEL: @complex_ranked_components
 func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
   %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
-  // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
-  %1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+  // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
+  %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
   return %1 : tensor<?x?xcomplex<f32>>
 }
@@ -29,8 +29,8 @@ func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
 // CHECK-LABEL: @compare_ranked_components
 func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
   %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
-  // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
-  %1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
+  // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
+  %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
   return %0 : tensor<?x?xi1>
 }
@@ -38,8 +38,8 @@ func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
 // CHECK-LABEL: @broadcast_add_ranked_components_r1
 func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
-  %1 = "xla_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
+  %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
   return %1 : tensor<?xf32>
 }
@@ -49,8 +49,8 @@ func @broadcast_add_ranked_components_r1x2(%arg0: tensor<?xf32>, %arg1: tensor<?x3xf32>) ->
   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
   // TODO: Overly broad shapes are being returned. Tighten the calculation
   // and update/extend these tests.
-  // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
-  %1 = "xla_test.get_return_type_components"(%0) : (tensor<?x3xf32>) -> tensor<?x3xf32>
+  // CHECK: "mhlo_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
+  %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<?x3xf32>) -> tensor<?x3xf32>
   return %1 : tensor<?x3xf32>
 }
diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir
index cfd15e6..2c0e2d7 100644
--- a/tests/chlo_legalize_to_hlo_broadcasts.mlir
+++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s

 // Check the non-broadcast case for each registered op, then just check a
 // representative op for detailed broadcast semantics.
diff --git a/tests/legalize-control-flow.mlir b/tests/legalize-control-flow.mlir
index 1d5faea..274792e 100644
--- a/tests/legalize-control-flow.mlir
+++ b/tests/legalize-control-flow.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -xla-legalize-control-flow %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-legalize-control-flow %s -o - | FileCheck %s

 // CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
 func @while(%arg0: tensor<i64>) -> tensor<i64> {
diff --git a/tests/legalize-to-std.mlir b/tests/legalize-to-std.mlir
index 774f926..37a6149 100644
--- a/tests/legalize-to-std.mlir
+++ b/tests/legalize-to-std.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -xla-legalize-to-std %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-legalize-to-std %s -o - | FileCheck %s

 // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
 func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
diff --git a/tests/legalize_tanh_to_approximation.mlir b/tests/legalize_tanh_to_approximation.mlir
index eaa3fdc..aa834d3 100644
--- a/tests/legalize_tanh_to_approximation.mlir
+++ b/tests/legalize_tanh_to_approximation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s

 func @tanh_f64(%arg0 : f64) -> f64 {
   %res = tanh %arg0 : f64
diff --git a/tests/lhlo-legalize-select-and-scatter.mlir b/tests/lhlo-legalize-select-and-scatter.mlir
index b110d8d..a6bb876 100644
--- a/tests/lhlo-legalize-select-and-scatter.mlir
+++ b/tests/lhlo-legalize-select-and-scatter.mlir
@@ -1,6 +1,6 @@
 // GenericAtomicRMWOp should contain only ops with no side effects.
 // Unfortunately, the legalization pattern for SelectAndScatterOp has to adapt
-// to XLA LHLO dialect using allocs/deallocs inside of GenericAtomicRMWOp body.
+// to LMHLO dialect using allocs/deallocs inside of GenericAtomicRMWOp body.
 // Lowering to STD dialect and store forwarding pass would be required to get
 // rid of them. This is exactly what is done in the real MLIR GPU pipeline, but
 // here we disable verification with `verify-each=0` to check the output IR.
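Note: the RUN lines above drive the renamed flags through mlir-hlo-opt; the same passes can be scheduled from C++. A minimal sketch, assuming the standard pass-manager headers and that the createMhloFusion factory declared in the mhlo_fusion.cc hunk is visible to the caller:

    #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
    #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Module.h"
    #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassManager.h"

    // Runs the renamed fusion pass over a module, mirroring what the
    // -mhlo-fusion flag requests from mlir-hlo-opt.
    mlir::LogicalResult RunFusion(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      // Function passes nest under the module-level pass manager.
      pm.nest<mlir::FuncOp>().addPass(mlir::mhlo::createMhloFusion());
      return pm.run(module);
    }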
diff --git a/tests/lower-complex.mlir b/tests/lower-complex.mlir
index 4db4d80..8d84e71 100644
--- a/tests/lower-complex.mlir
+++ b/tests/lower-complex.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s
+// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s

 // CHECK-LABEL: @add
 func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
diff --git a/tests/lower-general-dot.mlir b/tests/lower-general-dot.mlir
index 3ee23da..36cb1fd 100644
--- a/tests/lower-general-dot.mlir
+++ b/tests/lower-general-dot.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-test-lower-general-dot -split-input-file %s -o - | FileCheck %s

 // CHECK-LABEL: @testDebatch1
 func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
diff --git a/tests/materialize-broadcasts.mlir b/tests/materialize-broadcasts.mlir
index 4fd8b3d..682987d 100644
--- a/tests/materialize-broadcasts.mlir
+++ b/tests/materialize-broadcasts.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -mhlo-test-materialize-broadcasts -split-input-file %s -o - | FileCheck %s

 // CHECK-LABEL: @clampBroadcast
 // CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)
diff --git a/tests/xla-hlo-fusion.mlir b/tests/mhlo-fusion.mlir
similarity index 98%
rename from tests/xla-hlo-fusion.mlir
rename to tests/mhlo-fusion.mlir
index 6dc079a..d349077 100644
--- a/tests/xla-hlo-fusion.mlir
+++ b/tests/mhlo-fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s
+// RUN: mlir-hlo-opt %s -mhlo-fusion -split-input-file | FileCheck %s

 // CHECK-LABEL: func @multi_outputs_same
 func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
diff --git a/tests/xla-transform-unranked-hlo.mlir b/tests/mhlo-transform-unranked.mlir
similarity index 100%
rename from tests/xla-transform-unranked-hlo.mlir
rename to tests/mhlo-transform-unranked.mlir
diff --git a/tests/sink-constants-to-control-flow.mlir b/tests/sink-constants-to-control-flow.mlir
index 6a35239..f8b6b62 100644
--- a/tests/sink-constants-to-control-flow.mlir
+++ b/tests/sink-constants-to-control-flow.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s
+// RUN: mlir-hlo-opt %s -mhlo-sink-constants-to-control-flow | FileCheck %s

 // Tests sinking constants to a while loop.
diff --git a/tests/unfuse_batch_norm.mlir b/tests/unfuse_batch_norm.mlir
index 296bba7..b748ab4 100644
--- a/tests/unfuse_batch_norm.mlir
+++ b/tests/unfuse_batch_norm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s
+// RUN: mlir-hlo-opt -split-input-file -mhlo-test-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s

 // CHECK-LABEL: @batchNormInference_2D_inner_features
 // CHECK-SAME: %[[X:[^:[:space:]]+]]