From 8900222fede2fe3aea787132f3515a2fb7792523 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 7 Jul 2020 04:51:24 +0000 Subject: [PATCH] Rename `xla_hlo` dialect to `mhlo` This is part of the current refactoring of the HLO related dialect. `xla_hlo` will be reintroduced in a new form later. PiperOrigin-RevId: 319916753 --- include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td | 8 +- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h | 8 +- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 4 +- .../mhlo/transforms/map_hlo_to_lhlo_op.h | 12 +- .../mhlo/transforms/map_xla_to_scalar_op.h | 15 +- .../mlir-hlo/Dialect/mhlo/transforms/passes.h | 6 +- .../Dialect/mhlo/transforms/rewriters.h | 4 +- lib/Dialect/mhlo/IR/dialect_registration.cc | 2 +- lib/Dialect/mhlo/IR/hlo_ops.cc | 34 +- .../mhlo/transforms/chlo_legalize_to_hlo.cc | 71 ++- .../transforms/chlo_legalize_to_hlo_pass.cc | 4 +- .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 96 ++--- .../mhlo/transforms/legalize_control_flow.cc | 29 +- .../mhlo/transforms/legalize_to_standard.cc | 21 +- .../mhlo/transforms/lower_general_dot.cc | 22 +- .../mhlo/transforms/materialize_broadcasts.cc | 4 +- .../transforms/materialize_broadcasts_pass.cc | 12 +- .../{xla_hlo_fusion.cc => mhlo_fusion.cc} | 12 +- .../sink_constants_to_control_flow.cc | 4 +- .../mhlo/transforms/unfuse_batch_norm.cc | 43 +- .../mhlo/transforms/unfuse_batch_norm_pass.cc | 6 +- .../mhlo/transforms/xla_legalize_to_linalg.cc | 72 ++-- .../transforms/xla_transform_unranked_hlo.cc | 8 +- tests/canonicalize.mlir | 292 ++++++------- tests/chlo_legalize_to_hlo_broadcasts.mlir | 56 +-- tests/concatenate.mlir | 2 +- tests/convert.mlir | 106 ++--- tests/hlo-legalize-to-lhlo.mlir | 66 +-- tests/hlo-legalize-to-linalg.mlir | 92 ++-- tests/inlining.mlir | 14 +- tests/legalize-control-flow.mlir | 70 +-- tests/legalize-to-std.mlir | 82 ++-- tests/lhlo_ops.mlir | 36 +- tests/lower-complex.mlir | 218 +++++----- tests/lower-general-dot.mlir | 24 +- tests/materialize-broadcasts.mlir | 8 +- tests/ops.mlir | 404 +++++++++--------- tests/reduce.mlir | 8 +- tests/reshape.mlir | 80 ++-- tests/reverse.mlir | 2 +- tests/sink-constants-to-control-flow.mlir | 70 +-- tests/transpose.mlir | 10 +- tests/tuple.mlir | 4 +- tests/unfuse_batch_norm.mlir | 66 +-- tests/xla-hlo-fusion.mlir | 98 ++--- tests/xla-transform-unranked-hlo.mlir | 32 +- 46 files changed, 1163 insertions(+), 1174 deletions(-) rename lib/Dialect/mhlo/transforms/{xla_hlo_fusion.cc => mhlo_fusion.cc} (98%) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 4cf48c6..6b515ac 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -17,12 +17,12 @@ limitations under the License. // These ops are not necessarily orthogonal or optimized for transformation but // for ease of expression in certain cases deemed important for client // libraries (i.e. implicit broadcasting, helper ops, etc). -// This dialect is considered to exist in addition to augment the xla_hlo +// This dialect is considered to exist in addition to augment the mhlo // dialect for ergonomic needs, not duplicate/replace it. // // The typical use of this dialect is for client libraries to be able to emit // less constrained ops and rely on the conversion framework to lower any -// xla_chlo ops to canonical xla_hlo ops. +// xla_chlo ops to canonical mhlo ops. // // See: https://www.tensorflow.org/xla/operation_semantics @@ -44,7 +44,7 @@ def HLOClient_Dialect : Dialect { let description = [{ This dialect contains ops that align closely with the API surface area of the XlaBuilder C++ API, where such ops have semantics that go beyond - what exists in the lower level dialects (such as `xla_hlo`). Essentially, + what exists in the lower level dialects (such as `mhlo`). Essentially, whenever the client library uses syntactic sugar or composition of multiple ops for an API call, this dialect tries to model the API call and provide conversion patterns to fully materialize into lower level @@ -65,7 +65,7 @@ class HLOClient_Op traits> : // broadcasting (via the broadcast_dimensions attribute) and implicit degenerate // shape broadcasting. // -// These correspond to operations in the xla_hlo dialect without the +// These correspond to operations in the mhlo dialect without the // "broadcast_" prefix, except that those ops require same-shaped operands and // results. // diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index d945900..ee55325 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -37,12 +37,12 @@ class OpBuilder; #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc" -namespace xla_hlo { +namespace mhlo { class XlaHloDialect : public Dialect { public: explicit XlaHloDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "xla_hlo"; } + static StringRef getDialectNamespace() { return "mhlo"; } // Registered hook to materialize a constant operation from a given attribute // value with the desired resultant type. @@ -82,7 +82,7 @@ class TokenType : public Type::TypeBase { // %1 = index_cast %0 : index to i64 // %2 = dim %arg0, 1 : memref // %3 = index_cast %2 : index to i64 -// %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3) +// %4 = "mhlo.scalars_to_dimension_tensor"(%1, %3) // : (i64, i64) -> tensor<2xi64> // // and returns %4 as the shape value. @@ -93,7 +93,7 @@ LogicalResult deriveShapeFromFirstOperand( #define GET_OP_CLASSES #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" -} // end namespace xla_hlo +} // end namespace mhlo } // end namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 97a10d9..7cf1901 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -29,8 +29,8 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" def HLO_Dialect : Dialect { - let name = "xla_hlo"; - let cppNamespace = "xla_hlo"; + let name = "mhlo"; + let cppNamespace = "mhlo"; } class HLO_Op traits> : diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index 5e826c2..a05d1d3 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -22,7 +22,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { template struct HloToLhloOpImpl { @@ -31,10 +31,10 @@ struct HloToLhloOpImpl { template using HloToLhloOp = typename HloToLhloOpImpl::Type; -#define MAP_HLO_TO_LHLO(OpName) \ - template <> \ - struct HloToLhloOpImpl { \ - using Type = xla_lhlo::OpName; \ +#define MAP_HLO_TO_LHLO(OpName) \ + template <> \ + struct HloToLhloOpImpl { \ + using Type = xla_lhlo::OpName; \ } MAP_HLO_TO_LHLO(AbsOp); @@ -74,7 +74,7 @@ MAP_HLO_TO_LHLO(TanhOp); #undef MAP_HLO_TO_LHLO -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h index bb710a8..16a31f9 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h @@ -464,7 +464,7 @@ struct XlaOpToStdScalarOp { template ::value && - std::is_same, + std::is_same, std::false_type>::value>> static Value map(XlaOpTy op, ArrayRef result_types, ArrayRef args, OpBuilder* b, unsigned i = 0) { @@ -472,8 +472,8 @@ struct XlaOpToStdScalarOp { args, b); } - // Implementation for HLO ops except xla_hlo::CompareOp. - template , + // Implementation for HLO ops except mhlo::CompareOp. + template , typename = std::enable_if_t< !std::is_same::value && !std::is_same::value>> @@ -493,10 +493,11 @@ struct XlaOpToStdScalarOp { op.getLoc(), comparison_direction, result_types, args, b); } - // Implementation for xla_hlo::CompareOp. - template ::value>> - static Value map(xla_hlo::CompareOp op, ArrayRef result_types, + // Implementation for mhlo::CompareOp. + template ::value>> + static Value map(mhlo::CompareOp op, ArrayRef result_types, ArrayRef args, OpBuilder* b) { auto comparison_direction = op.comparison_direction(); return impl::MapXlaCompareOpToStdScalarOp( diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index 3471587..b279e15 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -29,7 +29,7 @@ template class OperationPass; class Pass; -namespace xla_hlo { +namespace mhlo { /// Lowers HLO control flow ops to the Standard dialect. std::unique_ptr> createLegalizeControlFlowPass(); @@ -55,10 +55,10 @@ std::unique_ptr> createTransformUnrankedHloPass(); // necessary to export to XLA. std::unique_ptr> createSinkConstantsToControlFlowPass(); -// fuse xla_hlo ops to kLoop/kInput fusion patterns +// fuse mhlo ops to kLoop/kInput fusion patterns std::unique_ptr> createXlaHloFusionPass(); -} // namespace xla_hlo +} // namespace mhlo namespace xla_lhlo { diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 606d510..fd0cc89 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -27,7 +27,7 @@ class LLVMTypeConverter; class LowerToLLVMOptions; class OwningRewritePatternList; class BufferAssignmentPlacer; -namespace xla_hlo { +namespace mhlo { // Collection of rewrite patterns for lowering a general dot product. void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns, @@ -73,7 +73,7 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context, void PopulateUnfuseBatchNormPatterns(MLIRContext *context, OwningRewritePatternList *patterns); -} // namespace xla_hlo +} // namespace mhlo namespace xla_lhlo { diff --git a/lib/Dialect/mhlo/IR/dialect_registration.cc b/lib/Dialect/mhlo/IR/dialect_registration.cc index 855c026..5e45b51 100644 --- a/lib/Dialect/mhlo/IR/dialect_registration.cc +++ b/lib/Dialect/mhlo/IR/dialect_registration.cc @@ -18,7 +18,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" // Static initialization for XLA dialect registration. -static mlir::DialectRegistration xla_hlo_ops; +static mlir::DialectRegistration mhlo_ops; static mlir::DialectRegistration xla_chlo_ops; static mlir::DialectRegistration xla_lhlo_ops; diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 0130f4b..90dac58 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -60,7 +60,7 @@ limitations under the License. namespace mlir { #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc" -namespace xla_hlo { +namespace mhlo { Operation* XlaHloDialect::materializeConstant(OpBuilder& builder, Attribute value, Type type, @@ -68,8 +68,7 @@ Operation* XlaHloDialect::materializeConstant(OpBuilder& builder, // HLO dialect constants only support ElementsAttr unlike standard dialect // constant which supports all attributes. if (value.isa()) - return builder.create(loc, type, - value.cast()); + return builder.create(loc, type, value.cast()); return nullptr; } @@ -167,7 +166,7 @@ void ConstOp::build(OpBuilder& builder, OperationState& result, } // TODO: support other XLA specific types. - assert(type && "unsupported attribute type for building xla_hlo.constant"); + assert(type && "unsupported attribute type for building mhlo.constant"); result.types.push_back(type); result.addAttribute("value", value); } @@ -387,7 +386,7 @@ static LogicalResult Verify(GetTupleElementOp op) { OpFoldResult GetTupleElementOp::fold(ArrayRef operands) { if (auto tupleOp = - dyn_cast_or_null(getOperand().getDefiningOp())) { + dyn_cast_or_null(getOperand().getDefiningOp())) { return tupleOp.getOperand(index().getLimitedValue()); } @@ -693,10 +692,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, } OpFoldResult ComplexOp::fold(ArrayRef operands) { - auto real_op = - dyn_cast_or_null(getOperand(0).getDefiningOp()); - auto imag_op = - dyn_cast_or_null(getOperand(1).getDefiningOp()); + auto real_op = dyn_cast_or_null(getOperand(0).getDefiningOp()); + auto imag_op = dyn_cast_or_null(getOperand(1).getDefiningOp()); if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) { return real_op.getOperand(); } @@ -727,7 +724,7 @@ void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) { OpFoldResult ImagOp::fold(ArrayRef operands) { if (auto complex_op = - dyn_cast_or_null(getOperand().getDefiningOp())) { + dyn_cast_or_null(getOperand().getDefiningOp())) { return complex_op.getOperand(1); } @@ -740,7 +737,7 @@ void RealOp::build(OpBuilder& builder, OperationState& state, Value val) { OpFoldResult RealOp::fold(ArrayRef operands) { if (auto complex_op = - dyn_cast_or_null(getOperand().getDefiningOp())) { + dyn_cast_or_null(getOperand().getDefiningOp())) { return complex_op.getOperand(0); } @@ -1148,7 +1145,7 @@ static LogicalResult Verify(MapOp op) { // RecvOp //===----------------------------------------------------------------------===// -// Checks that the result type is of the form `tuple` +// Checks that the result type is of the form `tuple` static LogicalResult Verify(RecvOp op) { auto result_ty = op.getResult().getType().cast(); auto subtypes = result_ty.getTypes(); @@ -2020,7 +2017,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" //===----------------------------------------------------------------------===// -// xla_hlo Dialect Interfaces +// mhlo Dialect Interfaces //===----------------------------------------------------------------------===// namespace { @@ -2032,7 +2029,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface { BlockAndValueMapping& valueMapping) const final { return true; } - // Operations in xla_hlo dialect are always legal to inline since they are + // Operations in mhlo dialect are always legal to inline since they are // pure. bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final { return true; @@ -2041,7 +2038,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface { } // end anonymous namespace //===----------------------------------------------------------------------===// -// xla_hlo Dialect Constructor +// mhlo Dialect Constructor //===----------------------------------------------------------------------===// XlaHloDialect::XlaHloDialect(MLIRContext* context) @@ -2061,8 +2058,7 @@ Type XlaHloDialect::parseType(DialectAsmParser& parser) const { if (parser.parseKeyword(&data_type)) return Type(); if (data_type == "token") return TokenType::get(getContext()); - parser.emitError(parser.getNameLoc()) - << "unknown xla_hlo type: " << data_type; + parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type; return nullptr; } @@ -2071,7 +2067,7 @@ void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const { os << "token"; return; } - os << ""; + os << ""; } //===----------------------------------------------------------------------===// @@ -2106,5 +2102,5 @@ LogicalResult deriveShapeFromFirstOperand( return success(); } -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index ed5282f..baaa8b8 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -30,7 +30,7 @@ namespace xla_chlo { namespace { // Converts binary ops that statically are determined to not broadcast directly -// to the corresponding xla_hlo non-broadcasting op. +// to the corresponding mhlo non-broadcasting op. template struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -63,7 +63,7 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern { }; // Converts a binary op with ranked broadcasting operands to explicitly -// broadcast and invoke the corresponding xla_hlo non-broadcasting op. +// broadcast and invoke the corresponding mhlo non-broadcasting op. // Note that dynamic broadcasting supported by this pattern is only valid for // "numpy" broadcasting semantics as defined here: // https://docs.scipy.org/doc/numpy/reference/ufuncs.html @@ -136,7 +136,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp // properly. auto lhs_broadcast_dimensions = llvm::to_vector<4>( llvm::seq(result_rank - lhs_type.getRank(), result_rank)); - Value broadcasted_lhs = rewriter.create( + Value broadcasted_lhs = rewriter.create( loc, RankedTensorType::get(result_type.getShape(), lhs_type.getElementType()), @@ -144,7 +144,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp rewriter.getI64TensorAttr(lhs_broadcast_dimensions)); auto rhs_broadcast_dimensions = llvm::to_vector<4>( llvm::seq(result_rank - rhs_type.getRank(), result_rank)); - Value broadcasted_rhs = rewriter.create( + Value broadcasted_rhs = rewriter.create( loc, RankedTensorType::get(result_type.getShape(), rhs_type.getElementType()), @@ -182,23 +182,21 @@ struct HloBinaryElementwiseAdaptor { }; struct HloComplexAdaptor { - static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op, - Type result_type, Value broadcasted_lhs, - Value broadcasted_rhs, - OpBuilder &builder) { - return builder.create(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs); + static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); } }; struct HloCompareAdaptor { - static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op, - Type result_type, Value broadcasted_lhs, - Value broadcasted_rhs, - OpBuilder &builder) { - return builder.create(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs, - from_op.comparison_direction()); + static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs, + from_op.comparison_direction()); } }; @@ -214,28 +212,27 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, HloBinaryElementwiseAdaptor>(context, \ patterns); - POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp); - POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp); - POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op); - POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp); - POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp); - POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp); - POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp); - POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp); - POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp); - POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp); - POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp); - POPULATE_BCAST(BroadcastShiftRightArithmeticOp, - xla_hlo::ShiftRightArithmeticOp); - POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp); - POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp); - POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp); + POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp); + POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp); + POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op); + POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp); + POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp); + POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp); + POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp); + POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp); + POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp); + POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp); + POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp); + POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp); + POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp); + POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp); + POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp); // Broadcasting ops requiring special construction. - PopulateForBinaryOp(context, patterns); - PopulateForBinaryOp(context, patterns); + PopulateForBinaryOp( + context, patterns); + PopulateForBinaryOp( + context, patterns); } } // namespace xla_chlo diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index 0e5d5b1..951a418 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -32,8 +32,8 @@ struct TestChloLegalizeToHloPass OwningRewritePatternList conversionPatterns; conversionTarget.addIllegalDialect(); - // Consider the xla_hlo dialect legal for tests. - conversionTarget.addLegalDialect(); + // Consider the mhlo dialect legal for tests. + conversionTarget.addLegalDialect(); // The conversion uses helpers from the Standard dialect. conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 6fd5805..83f60c5 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -37,7 +37,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { template @@ -128,20 +128,20 @@ class HloToLhloOpConverter : public BaseOpConversion { op->getLoc(), result.value(), results_shape.front(), &rewriter)); } } - rewriter.create>(op->getLoc(), llvm::None, - buffer_args, op->getAttrs()); + rewriter.create>(op->getLoc(), llvm::None, + buffer_args, op->getAttrs()); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); return success(); } }; struct HloToLhloDynamicBroadcastInDimOpConverter - : public BaseOpConversion { + : public BaseOpConversion { public: - using BaseOpConversion::BaseOpConversion; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - xla_hlo::DynamicBroadcastInDimOp op, ArrayRef operands, + mhlo::DynamicBroadcastInDimOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); Value resultBuffer = InsertDynamicAllocAndDealloc( @@ -162,7 +162,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter // and size of the target dimension if size-1 dimension expansion is // necessary. xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp( - xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const { + mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const { auto loc = op.getLoc(); auto operand_type = operand.getType().cast(); auto operand_shape = operand_type.getShape(); @@ -220,12 +220,12 @@ struct HloToLhloDynamicBroadcastInDimOpConverter } }; -struct HloToLhloReduceOpConverter : public BaseOpConversion { +struct HloToLhloReduceOpConverter : public BaseOpConversion { public: - using BaseOpConversion::BaseOpConversion; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - xla_hlo::ReduceOp op, ArrayRef operands, + mhlo::ReduceOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); // TODO(b/137624192) Implement variadic reduce. @@ -314,10 +314,10 @@ class HloToLhloTensorStoreOpConverter // "xla_lhlo.fusion"() ({ // %0 = tensor_load %arg1 : memref<2x2xf32> // %1 = tensor_load %arg2 : memref<2x2xf32> -// %2 = "xla_hlo.add"(%0, %1) : +// %2 = "mhlo.add"(%0, %1) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // %3 = tensor_load %arg0 : memref<2x2xf32> -// %4 = "xla_hlo.multiply"(%2, %3) : +// %4 = "mhlo.multiply"(%2, %3) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // tensor_store %4, %arg3 : memref<2x2xf32> // "xla_lhlo.terminator"() : () -> () @@ -344,8 +344,8 @@ class HloToLhloTensorStoreOpConverter // FuncOp signature conversion example: // // func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { -// %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> -// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>, +// %0 = "mhlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> +// tensor<4xf32> %1 = "mhlo.add"(%arg0, %0) : (tensor<4xf32>, // tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> // } // @@ -388,7 +388,7 @@ struct HloLegalizeToLhlo target.addIllegalOp(); target.addLegalOp(); target.addLegalOp(); - target.addIllegalDialect(); + target.addIllegalDialect(); BufferAssignmentTypeConverter converter; target.addDynamicallyLegalOp([&](FuncOp op) { @@ -442,38 +442,38 @@ void populateHLOToLHLOConversionPattern( // clang-format off patterns->insert< HloToLhloDynamicBroadcastInDimOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloReduceOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter @@ -489,5 +489,5 @@ std::unique_ptr> createLegalizeToLhloPass( static PassRegistration legalize_pass( "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect"); -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc index 87910af..83cd9c4 100644 --- a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc +++ b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc @@ -35,7 +35,7 @@ limitations under the License. using mlir::PassRegistration; namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { struct LegalizeControlFlow : public mlir::PassWrapper { @@ -51,7 +51,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block, OpBuilder* builder) { for (auto& old_block : region->getBlocks()) { Block* block = mapper.lookup(&old_block); - auto return_op = dyn_cast(block->getTerminator()); + auto return_op = dyn_cast(block->getTerminator()); if (!return_op) continue; builder->setInsertionPointToEnd(block); builder->create(loc, target_block, return_op.getOperands()); @@ -61,7 +61,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block, return success(); } -LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) { +LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) { Operation* op_inst = if_op.getOperation(); mlir::OpBuilder builder(if_op); auto orig_block = op_inst->getBlock(); @@ -106,13 +106,13 @@ LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) { return success(); } -LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { +LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { // Converts an XLA while loop into control flow. This generates a set of MLIR // blocks and branches, along with inlining the regions provided by the XLA // while loop. The structure should be similar to below: // // - // %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}} + // %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}} // auto* op_inst = while_op.getOperation(); mlir::OpBuilder builder(while_op); @@ -147,7 +147,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { // extract_element and conditional branch. This changes the block below: // ^cond(%0): // - // "xla_hlo".return(%1) + // "mhlo".return(%1) // // Into: // ^cond(%0): @@ -156,14 +156,14 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { // cond_br %2, ^body(%0), ^tail(%0) // Branch. builder.setInsertionPointToStart(cond_block); - // Replace the xla_hlo::ReturnOp with a branch back to the condition block. - // This is required as the xla_hlo::ReturnOp is used to mark the end of a + // Replace the mhlo::ReturnOp with a branch back to the condition block. + // This is required as the mhlo::ReturnOp is used to mark the end of a // block for regions nested inside of a operations (MLIR ReturnOp cannot be // nested within an non-function region). for (auto& block : while_op.cond()) { auto new_block = mapper.lookup(&block); - auto return_op = dyn_cast(new_block->getTerminator()); + auto return_op = dyn_cast(new_block->getTerminator()); if (!return_op) continue; builder.setInsertionPointToEnd(new_block); @@ -183,7 +183,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { // conditional block. This changes the block below: // ^body(%0): // - // "xla_hlo".return(%1) + // "mhlo".return(%1) // // Into: // ^body(%0): @@ -191,8 +191,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { // br ^cond(%0) // Branch. for (auto& block : while_op.body()) { auto new_block = mapper.lookup(&block); - auto return_op = - dyn_cast(new_block->getTerminator()); + auto return_op = dyn_cast(new_block->getTerminator()); if (!return_op) continue; builder.setInsertionPointToEnd(new_block); builder.create(loc, cond_block, return_op.getOperands()); @@ -224,14 +223,14 @@ void LegalizeControlFlow::runOnFunction() { } } } // namespace -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir std::unique_ptr> -mlir::xla_hlo::createLegalizeControlFlowPass() { +mlir::mhlo::createLegalizeControlFlowPass() { return std::make_unique(); } -static PassRegistration legalize_cf_pass( +static PassRegistration legalize_cf_pass( "xla-legalize-control-flow", "Legalize from XLA control flow to MLIR control flow"); diff --git a/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/lib/Dialect/mhlo/transforms/legalize_to_standard.cc index f4e7b49..0e59727 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_standard.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_standard.cc @@ -28,14 +28,14 @@ namespace mlir { namespace { #include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc" } // end anonymous namespace -namespace xla_hlo { +namespace mhlo { namespace { -class CompareIConvert : public OpRewritePattern { +class CompareIConvert : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(xla_hlo::CompareOp op, + LogicalResult matchAndRewrite(mhlo::CompareOp op, PatternRewriter &rewriter) const override { auto lhs = op.lhs(); auto rhs = op.rhs(); @@ -68,11 +68,11 @@ class CompareIConvert : public OpRewritePattern { } }; -class CompareFConvert : public OpRewritePattern { +class CompareFConvert : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(xla_hlo::CompareOp op, + LogicalResult matchAndRewrite(mhlo::CompareOp op, PatternRewriter &rewriter) const override { auto lhs = op.lhs(); auto rhs = op.rhs(); @@ -109,11 +109,11 @@ class CompareFConvert : public OpRewritePattern { // convert the integer constant to iota result type. For complex types, the real // part is replaced with the generated constant and the imaginary part is // replaced with zero tensor. -class ConvertIotaOp : public OpRewritePattern { +class ConvertIotaOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(xla_hlo::IotaOp op, + LogicalResult matchAndRewrite(mhlo::IotaOp op, PatternRewriter &rewriter) const override { auto output_type = op.getType().cast(); auto output_size = output_type.getNumElements(); @@ -168,8 +168,7 @@ class ConvertIotaOp : public OpRewritePattern { loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0))); auto imag_zeroes = rewriter.create(loc, int_or_float_shape_ty, zeroes); - rewriter.replaceOpWithNewOp(op, iota_const, - imag_zeroes); + rewriter.replaceOpWithNewOp(op, iota_const, imag_zeroes); return success(); } }; @@ -197,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, /// Perform the lowering to standard dialect. void LegalizeToStandard::runOnFunction() { OwningRewritePatternList patterns; - mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext()); + mlir::mhlo::PopulateXlaToStdPatterns(&patterns, &getContext()); applyPatternsAndFoldGreedily(getFunction(), patterns); } static PassRegistration legalize_pass( "xla-legalize-to-std", "Legalize from XLA dialect to standard dialect"); -} // end namespace xla_hlo +} // end namespace mhlo } // end namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/lib/Dialect/mhlo/transforms/lower_general_dot.cc index 4b38c34..40f3314 100644 --- a/lib/Dialect/mhlo/transforms/lower_general_dot.cc +++ b/lib/Dialect/mhlo/transforms/lower_general_dot.cc @@ -84,14 +84,14 @@ Value TransposeReshape(Value arg, mlir::Location loc, transposed_shape.push_back(arg_shape[val]); } auto transpose_type = RankedTensorType::get(transposed_shape, element_type); - auto transpose_result = rewriter->create( + auto transpose_result = rewriter->create( loc, transpose_type, arg, transpose_permutation_attr); // Return the final result. auto reshaped_type = RankedTensorType::get({left_size, right_size}, element_type); - return rewriter->create(loc, reshaped_type, - transpose_result); + return rewriter->create(loc, reshaped_type, + transpose_result); } Value ProcessDotArg(Value arg, mlir::Location loc, @@ -125,8 +125,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc, return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter); } -struct GeneralDotConvert - : public OpRewritePattern { +struct GeneralDotConvert : public OpRewritePattern { // Attempts to lower a General Dot operator to a standard Dot operator. // General dots include batching dimensions and can have collapsing // dimensions along any axis. Inserting correctly arrange transpose and @@ -138,7 +137,7 @@ struct GeneralDotConvert explicit GeneralDotConvert(MLIRContext *context) : OpRewritePattern(context) {} - LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op, + LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op, PatternRewriter &rewriter) const override { auto dot_element_type = mlir::getElementTypeOrSelf(op); @@ -162,11 +161,11 @@ struct GeneralDotConvert auto new_dot_type = RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type); - auto new_dot_op = rewriter.create( + auto new_dot_op = rewriter.create( op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config())); - rewriter.replaceOpWithNewOp(op, op.getType(), - new_dot_op); + rewriter.replaceOpWithNewOp(op, op.getType(), + new_dot_op); return success(); } }; @@ -176,15 +175,14 @@ struct LegalizeGeneralDot /// Lower all general dots that can be represented as a non-batched matmul. void runOnFunction() override { OwningRewritePatternList patterns; - mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, - &getContext()); + mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext()); applyPatternsAndFoldGreedily(getFunction(), patterns); } }; } // namespace -void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns( +void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns( OwningRewritePatternList *patterns, MLIRContext *ctx) { patterns->insert(ctx); } diff --git a/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc b/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc index 074f97c..8abc099 100644 --- a/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc +++ b/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc @@ -23,7 +23,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { @@ -86,5 +86,5 @@ void PopulateMaterializeBroadcastsPatterns(MLIRContext *context, patterns->insert(context); } -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc index 2106ec3..1f55bfa 100644 --- a/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc +++ b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc @@ -23,7 +23,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { @@ -33,7 +33,7 @@ struct TestMaterializeBroadcastsPass ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns; - // Consider the xla_hlo dialect legal for tests. + // Consider the mhlo dialect legal for tests. conversionTarget.addLegalDialect(); // The conversion uses helpers from the Standard dialect. conversionTarget.addLegalDialect(); @@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass } // namespace -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir -static mlir::PassRegistration - pass("test-xla-materialize-broadcasts", - "Test pass for materializing 'broadcast_dimensions' attributes"); +static mlir::PassRegistration pass( + "test-xla-materialize-broadcasts", + "Test pass for materializing 'broadcast_dimensions' attributes"); diff --git a/lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc b/lib/Dialect/mhlo/transforms/mhlo_fusion.cc similarity index 98% rename from lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc rename to lib/Dialect/mhlo/transforms/mhlo_fusion.cc index 2cde14a..568bceb 100644 --- a/lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc +++ b/lib/Dialect/mhlo/transforms/mhlo_fusion.cc @@ -60,7 +60,7 @@ limitations under the License. // shape dialect once it is ready. namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { using llvm::EquivalenceClasses; @@ -544,7 +544,7 @@ struct XlaHloFusion : public mlir::PassWrapper { } FusionOp fusion = - b.create(fused_loc, output_types, inputs); + b.create(fused_loc, output_types, inputs); Region& region = fusion.fused_computation(); region.push_back(new Block); Block& block = region.front(); @@ -552,7 +552,7 @@ struct XlaHloFusion : public mlir::PassWrapper { op->moveBefore(&block, block.end()); } b.setInsertionPoint(&block, block.end()); - b.create(fused_loc, outputs); + b.create(fused_loc, outputs); for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) { Value output = std::get<0>(output_and_result); @@ -572,8 +572,8 @@ std::unique_ptr> createXlaHloFusion() { return std::make_unique(); } -static PassRegistration xla_hlo_fusion_pass( - "xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns."); +static PassRegistration mhlo_fusion_pass( + "xla-hlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns."); -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc index 666ca53..dd2e663 100644 --- a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc +++ b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc @@ -23,7 +23,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { @@ -81,5 +81,5 @@ std::unique_ptr> createSinkConstantsToControlFlowPass() { return std::make_unique(); } -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc index b0fc6a1..5028e28 100644 --- a/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc +++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc @@ -25,7 +25,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { @@ -40,12 +40,12 @@ Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type, auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64)); auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim}); if (shape_value) { - return rewriter.createOrFold( + return rewriter.createOrFold( loc, result_type, value_1d, shape_value, dims); } assert(result_type.hasStaticShape()); - return rewriter.create(loc, result_type, value_1d, - dims); + return rewriter.create(loc, result_type, value_1d, + dims); } // Calculate the shape value of operand, assuming it is a dynamic shape with @@ -89,25 +89,25 @@ Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr, auto epsilon_tensor_attr = DenseElementsAttr::get(scalar_type, {epsilon_attr.cast()}); Value epsilon = - rewriter.create(op->getLoc(), epsilon_tensor_attr); + rewriter.create(op->getLoc(), epsilon_tensor_attr); auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64)); auto dims = DenseIntElementsAttr::get(dims_type, SmallVector{}); if (broadcast_to_type.hasStaticShape()) { - return rewriter.create( + return rewriter.create( op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims); } Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter); - return rewriter.createOrFold( + return rewriter.createOrFold( op->getLoc(), broadcast_to_type, epsilon, shape_value, /*broadcast_dims=*/dims); } class UnfuseBatchNormInferencePattern - : public OpRewritePattern { + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op, + LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op, PatternRewriter& rewriter) const override { // Enforce type invariants. // Note that we deduce the actual element type from the variance, @@ -132,9 +132,9 @@ class UnfuseBatchNormInferencePattern if (!epsilon) { return failure(); } - Value stddev = rewriter.create(bn_op.getLoc(), - bn_op.variance(), epsilon); - stddev = rewriter.create(bn_op.getLoc(), stddev); + Value stddev = + rewriter.create(bn_op.getLoc(), bn_op.variance(), epsilon); + stddev = rewriter.create(bn_op.getLoc(), stddev); // Broadcast all terms. Value shape_value; @@ -156,14 +156,13 @@ class UnfuseBatchNormInferencePattern // Compute: // scale * (input - mean) / stddev + offset - Value result = rewriter.create( - bn_op.getLoc(), bn_op.operand(), broadcast_mean); - result = rewriter.create(bn_op.getLoc(), result, - broadcast_scale); - result = rewriter.create(bn_op.getLoc(), result, - broadcast_stddev); - rewriter.replaceOpWithNewOp(bn_op, result, - broadcast_offset); + Value result = rewriter.create(bn_op.getLoc(), bn_op.operand(), + broadcast_mean); + result = + rewriter.create(bn_op.getLoc(), result, broadcast_scale); + result = + rewriter.create(bn_op.getLoc(), result, broadcast_stddev); + rewriter.replaceOpWithNewOp(bn_op, result, broadcast_offset); return success(); } @@ -180,5 +179,5 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext* context, patterns->insert(context); } -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc index 179b63c..4a5b5fd 100644 --- a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc +++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc @@ -23,7 +23,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { @@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass } // namespace -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir -static mlir::PassRegistration pass( +static mlir::PassRegistration pass( "test-xla-unfuse-batch-norm", "Test pass for materializing 'broadcast_dimensions' attributes"); diff --git a/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc index 66a9aaa..9dd69b8 100644 --- a/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc @@ -182,7 +182,7 @@ struct ConvToLinalgConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; // This code has been adapted from IREE's - // (https://github.com/google/iree/) xla_hlo -> linalg conversion. + // (https://github.com/google/iree/) mhlo -> linalg conversion. LogicalResult matchAndRewrite( xla_lhlo::ConvOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { @@ -348,14 +348,14 @@ class BroadcastConverter class HloBroadcastInDimConverter : public DataMovementOpConverter { + mhlo::BroadcastInDimOp, false> { public: using DataMovementOpConverter::DataMovementOpConverter; static SmallVector getIndexingMaps( - xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) { + mhlo::BroadcastInDimOp broadcastOp, Builder* b) { auto resultType = getXLAOpResultType(broadcastOp); auto operandType = broadcastOp.operand().getType().template cast(); @@ -845,7 +845,7 @@ struct HloLegalizeToLinalg target.addLegalDialect(); auto func = getFunction(); - xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); + mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); if (failed(applyPartialConversion(func, target, patterns, nullptr))) { signalPassFailure(); } @@ -863,40 +863,40 @@ static PassRegistration legalize_lhlo_pass( "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect"); } // namespace xla_lhlo -namespace xla_hlo { +namespace mhlo { void populateHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { - patterns->insert, + patterns->insert, HloBroadcastInDimConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - ReshapeOpConverter, - ReverseConverter, - TransposeConverter>(context); + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + ReshapeOpConverter, + ReverseConverter, + TransposeConverter>(context); } std::unique_ptr> createLegalizeHloToLinalgPass() { @@ -905,5 +905,5 @@ std::unique_ptr> createLegalizeHloToLinalgPass() { static PassRegistration legalize_hlo_pass( "hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect"); -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc index fde9cef..c238085 100644 --- a/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc @@ -28,7 +28,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_hlo { +namespace mhlo { namespace { // TODO(frgossen): Make it variadic. @@ -69,7 +69,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern { rewriter.create(loc, numElementsAsIndex); auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, operandTy.getElementType()); - Value flatOperand = rewriter.create( + Value flatOperand = rewriter.create( loc, flatTensorTy, operand, flatShapeAsDimTensor); // Generate IR for the actual operation. @@ -80,7 +80,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern { rewriter.getIndexType()); Value shapeAsExtentTensor = rewriter.create(loc, extentTensorTy, shape); - Value result = rewriter.create( + Value result = rewriter.create( loc, operandTy, flatResult, shapeAsExtentTensor); rewriter.replaceOp(op, result); @@ -184,5 +184,5 @@ static PassRegistration transform_unranked_hlo_pass( "transform-unranked-hlo", "Realize element-wise operations on ranked tensors where possible"); -} // namespace xla_hlo +} // namespace mhlo } // namespace mlir diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 90dd8f1..1124a46 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -2,107 +2,107 @@ // CHECK-LABEL: add_fold func @add_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64> - %1 = xla_hlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<[6, 8, 10, 12]> - %2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + %0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64> + %1 = mhlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64> + // CHECK: mhlo.constant dense<[6, 8, 10, 12]> + %2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) return %2 : tensor<4xi64> } // CHECK-LABEL: add_scalar_fold func @add_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<1> : tensor<4xi64> - %1 = xla_hlo.constant dense<5> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<6> - %2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + %0 = mhlo.constant dense<1> : tensor<4xi64> + %1 = mhlo.constant dense<5> : tensor<4xi64> + // CHECK: mhlo.constant dense<6> + %2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) return %2 : tensor<4xi64> } // CHECK-LABEL: add_fold_float func @add_fold_float() -> tensor<4xf64> { - %0 = xla_hlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> - %1 = xla_hlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64> - // CHECK: xla_hlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]> - %2 = "xla_hlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + %0 = mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> + %1 = mhlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64> + // CHECK: mhlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]> + %2 = "mhlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) return %2 : tensor<4xf64> } // CHECK-LABEL: sub_scalar_fold func @sub_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<5> : tensor<4xi64> - %1 = xla_hlo.constant dense<1> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<4> - %2 = "xla_hlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + %0 = mhlo.constant dense<5> : tensor<4xi64> + %1 = mhlo.constant dense<1> : tensor<4xi64> + // CHECK: mhlo.constant dense<4> + %2 = "mhlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) return %2 : tensor<4xi64> } // CHECK-LABEL: multiply_scalar_fold func @multiply_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<5> : tensor<4xi64> - %1 = xla_hlo.constant dense<3> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<15> - %2 = "xla_hlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + %0 = mhlo.constant dense<5> : tensor<4xi64> + %1 = mhlo.constant dense<3> : tensor<4xi64> + // CHECK: mhlo.constant dense<15> + %2 = "mhlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) return %2 : tensor<4xi64> } // CHECK-LABEL: divide_scalar_fold func @divide_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<7> : tensor<4xi64> - %1 = xla_hlo.constant dense<5> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<1> - %2 = "xla_hlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + %0 = mhlo.constant dense<7> : tensor<4xi64> + %1 = mhlo.constant dense<5> : tensor<4xi64> + // CHECK: mhlo.constant dense<1> + %2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) return %2 : tensor<4xi64> } // CHECK-LABEL: divide_fold_float func @divide_fold_float() -> tensor<4xf64> { - %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> - %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> - // CHECK: xla_hlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]> - %2 = "xla_hlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + %0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> + %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> + // CHECK: mhlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]> + %2 = "mhlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) return %2 : tensor<4xf64> } // CHECK-LABEL: max_scalar_fold func @max_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<7> : tensor<4xi64> - %1 = xla_hlo.constant dense<5> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<7> - %2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + %0 = mhlo.constant dense<7> : tensor<4xi64> + %1 = mhlo.constant dense<5> : tensor<4xi64> + // CHECK: mhlo.constant dense<7> + %2 = "mhlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) return %2 : tensor<4xi64> } // CHECK-LABEL: max_fold_float func @max_fold_float() -> tensor<4xf64> { - %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> - %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> - // CHECK: xla_hlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]> - %2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + %0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> + %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> + // CHECK: mhlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]> + %2 = "mhlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) return %2 : tensor<4xf64> } // CHECK-LABEL: min_scalar_fold func @min_scalar_fold() -> tensor<4xi64> { - %0 = xla_hlo.constant dense<7> : tensor<4xi64> - %1 = xla_hlo.constant dense<-5> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<-5> - %2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + %0 = mhlo.constant dense<7> : tensor<4xi64> + %1 = mhlo.constant dense<-5> : tensor<4xi64> + // CHECK: mhlo.constant dense<-5> + %2 = "mhlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) return %2 : tensor<4xi64> } // CHECK-LABEL: min_fold_float func @min_fold_float() -> tensor<4xf64> { - %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> - %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> - // CHECK: xla_hlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]> - %2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + %0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> + %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> + // CHECK: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]> + %2 = "mhlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) return %2 : tensor<4xf64> } // CHECK-LABEL: concatenate_noop func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32> - %0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32> + %0 = "mhlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32> // CHECK: return [[ARG]] return %0 : tensor<4xi32> @@ -112,7 +112,7 @@ func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> { // CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32> // CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32> - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32> // CHECK: return [[ARG0]] return %0 : tensor<4xi32> @@ -120,34 +120,34 @@ func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> // CHECK-LABEL: concatenate_empty_bool func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> { - // CHECK: xla_hlo.constant - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> + // CHECK: mhlo.constant + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> return %0 : tensor<0xi1> } // CHECK-LABEL: concatenate_empty_int func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> { - // CHECK: xla_hlo.constant - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32> + // CHECK: mhlo.constant + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32> return %0 : tensor<0xi32> } // CHECK-LABEL: concatenate_empty_float func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> { - // CHECK: xla_hlo.constant - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> + // CHECK: mhlo.constant + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> return %0 : tensor<0xf32> } // CHECK-LABEL: concatenate_const_1D func @concatenate_const_1D() -> tensor<4xi32> { - // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[0, 1, 2, 3]> - %0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32> - %1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32> - %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32> + // CHECK: [[VAL:%.+]]= mhlo.constant dense<[0, 1, 2, 3]> + %0 = mhlo.constant dense<[0, 1]> : tensor<2xi32> + %1 = mhlo.constant dense<[2, 3]> : tensor<2xi32> + %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32> // CHECK: return [[VAL]] return %2 : tensor<4xi32> @@ -155,11 +155,11 @@ func @concatenate_const_1D() -> tensor<4xi32> { // CHECK-LABEL: concatenate_const_1D_float func @concatenate_const_1D_float() -> tensor<4xf32> { - // CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> + // CHECK: [[VAL:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> - %0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32> - %1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32> - %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32> + %0 = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32> + %1 = mhlo.constant dense<[2.0, 3.0]> : tensor<2xf32> + %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32> // CHECK: return [[VAL]] return %2 : tensor<4xf32> @@ -167,12 +167,12 @@ func @concatenate_const_1D_float() -> tensor<4xf32> { // CHECK-LABEL: concatenate_const_2D_vertical func @concatenate_const_2D_vertical() -> tensor<2x2xi32> { - // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[ + // CHECK: [[VAL:%.+]]= mhlo.constant dense<[ // CHECK-SAME: [0, 1], [2, 3] // CHECK-SAME: ]> - %0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32> - %1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32> - %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> + %0 = mhlo.constant dense<[[0, 1]]> : tensor<1x2xi32> + %1 = mhlo.constant dense<[[2, 3]]> : tensor<1x2xi32> + %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> // CHECK: return [[VAL]] return %2 : tensor<2x2xi32> @@ -180,12 +180,12 @@ func @concatenate_const_2D_vertical() -> tensor<2x2xi32> { // CHECK-LABEL: concatenate_const_2D_horizontal func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> { - // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[ + // CHECK: [[VAL:%.+]]= mhlo.constant dense<[ // CHECK-SAME: [0, 2], [1, 3] // CHECK-SAME: ]> - %0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32> - %1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32> - %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> + %0 = mhlo.constant dense<[[0], [1]]> : tensor<2x1xi32> + %1 = mhlo.constant dense<[[2], [3]]> : tensor<2x1xi32> + %2 = "mhlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> // CHECK: return [[VAL]] return %2 : tensor<2x2xi32> @@ -193,40 +193,40 @@ func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> { // CHECK-LABEL: dynamic_slice_variable_start func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { - // CHECK: "xla_hlo.dynamic-slice" - %1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + // CHECK: "mhlo.dynamic-slice" + %1 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> return %1 : tensor<1x4xi32> } // CHECK-LABEL: dynamic_slice_constant_start func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0) + // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) // CHECK-DAG-SAME: limit_indices = dense<3> : tensor<1xi64> // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} // CHECK: return %[[RESULT]] : tensor<2xi32> - %0 = xla_hlo.constant dense<1> : tensor - %1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> + %0 = mhlo.constant dense<1> : tensor + %1 = "mhlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> return %1 : tensor<2xi32> } // CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor { - // CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0) + // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) // CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64> // CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64> // CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64> // CHECK: return %[[RESULT]] : tensor - %0 = xla_hlo.constant dense<1> : tensor - %1 = xla_hlo.constant dense<0> : tensor - %2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor, tensor) -> tensor + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + %2 = "mhlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: slice_2D_noop // CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64> func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { - %0 = "xla_hlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>) + %0 = "mhlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>) // CHECK-NEXT: return [[ARG]] return %0 : tensor<2x2xi64> @@ -234,80 +234,80 @@ func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { // CHECK-LABEL: slice_1D_fold func @slice_1D_fold() -> tensor<2xi64> { - %0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<[7, 9]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) + %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> + // CHECK: mhlo.constant dense<[7, 9]> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) return %1 : tensor<2xi64> } // CHECK-LABEL: slice_1D_fp func @slice_1D_fp() -> tensor<2xf32> { - %0 = xla_hlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32> - // CHECK: xla_hlo.constant dense<[7.000000e+00, 9.000000e+00]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>) + %0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32> + // CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>) return %1 : tensor<2xf32> } // CHECK-LABEL: slice_1D_strided_fold func @slice_1D_strided_fold() -> tensor<2xi64> { - %0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> - // CHECK: xla_hlo.constant dense<[7, 10]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) + %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> + // CHECK: mhlo.constant dense<[7, 10]> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) return %1 : tensor<2xi64> } // CHECK-LABEL: slice_2D_fold func @slice_2D_fold() -> tensor<2x2xi64> { - %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> - // CHECK-NEXT: xla_hlo.constant dense<[ + %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: mhlo.constant dense<[ // CHECK-SAME: [6, 7], // CHECK-SAME: [10, 11] // CHECK-SAME: ]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>) + %1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>) return %1 : tensor<2x2xi64> } // CHECK-LABEL: slice_2D_fold_horizontal func @slice_2D_fold_horizontal() -> tensor<1x4xi64> { - %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> - // CHECK-NEXT: xla_hlo.constant dense<[ + %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: mhlo.constant dense<[ // CHECK-SAME: [0, 1, 2, 3] // CHECK-SAME: ]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>) + %1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>) return %1 : tensor<1x4xi64> } // CHECK-LABEL: slice_2D_fold_vertical func @slice_2D_fold_vertical() -> tensor<4x1xi64> { - %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> - // CHECK-NEXT: xla_hlo.constant dense<[ + %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: mhlo.constant dense<[ // CHECK-SAME: [2], [6], [10], [14] // CHECK-SAME: ]> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>) + %1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>) return %1 : tensor<4x1xi64> } // CHECK-LABEL: slice_concat_fold_first func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) // CHECK: return %arg0 return %1 : tensor<1x5xf32> } // CHECK-LABEL: slice_concat_fold_second func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) // CHECK: return %arg1 return %1 : tensor<1x5xf32> } // CHECK-LABEL: slice_concat_fold_second_with_slice func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - // CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32> - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>) + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>) // CHECK: return [[SLICE]] return %1 : tensor<1x4xf32> @@ -315,9 +315,9 @@ func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor< // CHECK-LABEL: slice_concat_fold_middle func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> - // CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>) + %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> + // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + %1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>) // CHECK: return [[SLICE]] return %1 : tensor<1x5xf32> @@ -325,11 +325,11 @@ func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, % // CHECK-LABEL: slice_concat_fold_two func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> { - // CHECK: [[CONCAT:%.+]] = "xla_hlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64} - %0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> + // CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64} + %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> - // CHECK: [[SLICE:%.+]] = "xla_hlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>) + // CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + %1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>) // CHECK: return [[SLICE]] return %1 : tensor<2x5xf32> @@ -338,72 +338,72 @@ func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg // CHECK-LABEL: func @broadcast_in_dim_identity func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { // CHECK: return %arg0 - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> return %0 : tensor<2x3x4xf32> } // CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { - // CHECK: xla_hlo.broadcast_in_dim - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32> + // CHECK: mhlo.broadcast_in_dim + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } // CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: xla_hlo.broadcast_in_dim - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: mhlo.broadcast_in_dim + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { - // CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32> - %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32> + // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32> + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32> // CHECK: return %[[RESULT]] : tensor<5x4xf32> return %0 : tensor<5x4xf32> } // CHECK-LABEL: @complex_expand_fold func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex>) - %1 = "xla_hlo.real"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) - %2 = "xla_hlo.imag"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex>) + %1 = "mhlo.real"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + %2 = "mhlo.imag"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) // CHECK: return %arg0, %arg1 return %1, %2 : tensor<4xf32>, tensor<4xf32> } // CHECK-LABEL: @complex_collapse_fold func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { - %0 = "xla_hlo.real"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) - %1 = "xla_hlo.imag"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) - %2 = "xla_hlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + %0 = "mhlo.real"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + %1 = "mhlo.imag"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + %2 = "mhlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> // CHECK: return %arg0 return %2 : tensor<4xcomplex> } // CHECK-LABEL: @dynamic_iota_is_static func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> { - // CHECK: [[RESULT:%.*]] = "xla_hlo.iota" + // CHECK: [[RESULT:%.*]] = "mhlo.iota" // CHECK: return [[RESULT]] - %0 = "xla_hlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32> + %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32> return %0 : tensor<4xi32> } // CHECK-LABEL: @iota_not_lowered_to_constant func @iota_not_lowered_to_constant() -> tensor<4xi32> { - // CHECK: [[RESULT:%.*]] = "xla_hlo.iota" + // CHECK: [[RESULT:%.*]] = "mhlo.iota" // CHECK: return [[RESULT]] - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> return %0 : tensor<4xi32> } // CHECK-LABEL: @unary_einsum func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { - // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor - // CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"} - %0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> + // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: "mhlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"} + %0 = "mhlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -411,30 +411,30 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> { // CHECK: return [[ARG]] - %0 = "xla_hlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32> + %0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } // CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> { - // CHECK: xla_hlo.reshape - %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32> + // CHECK: mhlo.reshape + %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32> return %0 : tensor<4x1xf32> } // CHECK-LABEL: do_not_dce_while_with_outfeed func @do_not_dce_while_with_outfeed(%arg0: tensor) -> tensor { - // CHECK: xla_hlo.while - %0 = "xla_hlo.while"(%arg0) ( { + // CHECK: mhlo.while + %0 = "mhlo.while"(%arg0) ( { ^bb0(%arg1: tensor): - %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg1: tensor): - %1 = "xla_hlo.create_token"() : () -> !xla_hlo.token + %1 = "mhlo.create_token"() : () -> !mhlo.token // Side-effecting op outfeed present inside while. - %2 = "xla_hlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor, !xla_hlo.token) -> !xla_hlo.token - "xla_hlo.return"(%arg1) : (tensor) -> () + %2 = "mhlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor, !mhlo.token) -> !mhlo.token + "mhlo.return"(%arg1) : (tensor) -> () }) : (tensor) -> tensor return %arg0 : tensor @@ -442,15 +442,15 @@ func @do_not_dce_while_with_outfeed(%arg0: tensor) -> tensor { // CHECK-LABEL: dce_while_without_side_effect func @dce_while_without_side_effect(%arg0: tensor) -> tensor { - // CHECK-NOT: xla_hlo.while - %0 = "xla_hlo.while"(%arg0) ( { + // CHECK-NOT: mhlo.while + %0 = "mhlo.while"(%arg0) ( { ^bb0(%arg1: tensor): - %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg1: tensor): - %1 = "xla_hlo.create_token"() : () -> !xla_hlo.token - "xla_hlo.return"(%arg1) : (tensor) -> () + %1 = "mhlo.create_token"() : () -> !mhlo.token + "mhlo.return"(%arg1) : (tensor) -> () }) : (tensor) -> tensor return %arg0 : tensor diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir index b290dcb..78617b7 100644 --- a/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -4,7 +4,7 @@ // representative op for detailed broadcast semantics. // CHECK-LABEL: @addWithoutBroadcast func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.add %arg0, %arg1 + // CHECK: mhlo.add %arg0, %arg1 %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -20,9 +20,9 @@ func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor : tensor<1xi64>} - // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]] + // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]] // CHECK-NEXT: shape.assuming_yield %[[RESULT]] // CHECK-NEXT: } // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor @@ -41,9 +41,9 @@ func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> t // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] // CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor, tensor) -> tensor> + // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor, tensor) -> tensor> // CHECK-NEXT: shape.assuming_yield %[[RESULT]] // CHECK-NEXT: } // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor> @@ -62,9 +62,9 @@ func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> t // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] // CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor // CHECK: shape.assuming_yield %[[RESULT]] // CHECK-NEXT: } // CHECK: return %[[FINAL_RESULT]] : tensor @@ -76,7 +76,7 @@ func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> t // Verifies that broadcast_dimensions validity checks are valid. // CHECK-LABEL: @dynamicNonScalarBroadcastDimensions func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK: xla_hlo.add + // CHECK: mhlo.add %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -85,7 +85,7 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor< // Verifies that broadcast_dimensions validity checks are valid. // CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor) -> tensor<1x4xf32> { - // CHECK: xla_hlo.add + // CHECK: mhlo.add %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -113,7 +113,7 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: // expansions. Tests below merely verify that the op has an expansion. // CHECK-LABEL: @andWithoutBroadcast func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: xla_hlo.and %arg0, %arg1 + // CHECK: mhlo.and %arg0, %arg1 %0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } @@ -121,7 +121,7 @@ func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x // ----- // CHECK-LABEL: @atan2WithoutBroadcast func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.atan2 %arg0, %arg1 + // CHECK: mhlo.atan2 %arg0, %arg1 %0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -129,7 +129,7 @@ func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso // ----- // CHECK-LABEL: @compareWithoutBroadcast func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> { - // CHECK: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> return %0 : tensor<4xi1> } @@ -137,7 +137,7 @@ func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten // ----- // CHECK-LABEL: @complexWithoutBroadcast func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex> { - // CHECK: "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + // CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> return %0 : tensor<4xcomplex> } @@ -145,7 +145,7 @@ func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten // ----- // CHECK-LABEL: @divideWithoutBroadcast func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.divide %arg0, %arg1 + // CHECK: mhlo.divide %arg0, %arg1 %0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -153,7 +153,7 @@ func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tens // ----- // CHECK-LABEL: @maximumWithoutBroadcast func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.maximum %arg0, %arg1 + // CHECK: mhlo.maximum %arg0, %arg1 %0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -161,7 +161,7 @@ func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten // ----- // CHECK-LABEL: @minimumWithoutBroadcast func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.minimum %arg0, %arg1 + // CHECK: mhlo.minimum %arg0, %arg1 %0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -169,7 +169,7 @@ func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten // ----- // CHECK-LABEL: @multiplyWithoutBroadcast func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.multiply %arg0, %arg1 + // CHECK: mhlo.multiply %arg0, %arg1 %0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -177,7 +177,7 @@ func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te // ----- // CHECK-LABEL: @orWithoutBroadcast func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: xla_hlo.or %arg0, %arg1 + // CHECK: mhlo.or %arg0, %arg1 %0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } @@ -185,7 +185,7 @@ func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi // ----- // CHECK-LABEL: @powerWithoutBroadcast func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.power %arg0, %arg1 + // CHECK: mhlo.power %arg0, %arg1 %0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -193,7 +193,7 @@ func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso // ----- // CHECK-LABEL: @remainderWithoutBroadcast func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.remainder %arg0, %arg1 + // CHECK: mhlo.remainder %arg0, %arg1 %0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -201,7 +201,7 @@ func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> t // ----- // CHECK-LABEL: @shift_leftWithoutBroadcast func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.shift_left %arg0, %arg1 + // CHECK: mhlo.shift_left %arg0, %arg1 %0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -209,7 +209,7 @@ func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> // ----- // CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 + // CHECK: mhlo.shift_right_arithmetic %arg0, %arg1 %0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -217,7 +217,7 @@ func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor // ----- // CHECK-LABEL: @shift_right_logicalWithoutBroadcast func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.shift_right_logical %arg0, %arg1 + // CHECK: mhlo.shift_right_logical %arg0, %arg1 %0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -225,7 +225,7 @@ func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4x // ----- // CHECK-LABEL: @subWithoutBroadcast func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.subtract %arg0, %arg1 + // CHECK: mhlo.subtract %arg0, %arg1 %0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -233,7 +233,7 @@ func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor< // ----- // CHECK-LABEL: @xorWithoutBroadcast func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: xla_hlo.xor %arg0, %arg1 + // CHECK: mhlo.xor %arg0, %arg1 %0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } diff --git a/tests/concatenate.mlir b/tests/concatenate.mlir index 179616e..aeefd68 100644 --- a/tests/concatenate.mlir +++ b/tests/concatenate.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func @single_operand // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> { - %0 = "xla_hlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %0 = "mhlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32> // CHECK-NEXT: return [[ARG]] return %0 : tensor<1x2xf32> } \ No newline at end of file diff --git a/tests/convert.mlir b/tests/convert.mlir index 783fe8a..dab395c 100644 --- a/tests/convert.mlir +++ b/tests/convert.mlir @@ -5,7 +5,7 @@ // CHECK-LABEL: func @same_type // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @same_type(%arg: tensor) -> tensor { - %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + %0 = "mhlo.convert"(%arg) : (tensor) -> tensor // CHECK-NEXT: return [[ARG]] return %0 : tensor } @@ -15,8 +15,8 @@ func @same_type(%arg: tensor) -> tensor { // CHECK-LABEL: func @int_widening // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @int_widening(%arg: tensor) -> tensor { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor - %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "mhlo.convert"(%arg) : (tensor) -> tensor // CHECK-NEXT: return [[RES]] return %0 : tensor } @@ -26,8 +26,8 @@ func @int_widening(%arg: tensor) -> tensor { // CHECK-LABEL: func @int_narrowing // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @int_narrowing(%arg: tensor) -> tensor { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor - %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "mhlo.convert"(%arg) : (tensor) -> tensor // CHECK-NEXT: return [[RES]] return %0 : tensor } @@ -37,8 +37,8 @@ func @int_narrowing(%arg: tensor) -> tensor { // CHECK-LABEL: func @float_int // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @float_int(%arg: tensor) -> tensor { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor - %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "mhlo.convert"(%arg) : (tensor) -> tensor // CHECK-NEXT: return [[RES]] return %0 : tensor } @@ -48,8 +48,8 @@ func @float_int(%arg: tensor) -> tensor { // CHECK-LABEL: func @int_float // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @int_float(%arg: tensor) -> tensor { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor - %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "mhlo.convert"(%arg) : (tensor) -> tensor // CHECK-NEXT: return [[RES]] return %0 : tensor } @@ -59,8 +59,8 @@ func @int_float(%arg: tensor) -> tensor { // CHECK-LABEL: func @high_rank_tensor // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32> - %0 = "xla_hlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32> + // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32> + %0 = "mhlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32> // CHECK-NEXT: return [[RES]] return %0 : tensor<2x3xf32> } @@ -70,9 +70,9 @@ func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> { // CHECK-LABEL: func @const_same_type func @const_same_type() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -81,9 +81,9 @@ func @const_same_type() -> tensor { // CHECK-LABEL: func @const_float_int func @const_float_int() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42.0> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42.0> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -92,9 +92,9 @@ func @const_float_int() -> tensor { // CHECK-LABEL: func @const_int_float func @const_int_float() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor - %cst = xla_hlo.constant dense<4> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor + %cst = mhlo.constant dense<4> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -103,9 +103,9 @@ func @const_int_float() -> tensor { // CHECK-LABEL: func @const_negative_int_float func @const_negative_int_float() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-4.{{0*}}e+00> : tensor - %cst = xla_hlo.constant dense<-4> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-4.{{0*}}e+00> : tensor + %cst = mhlo.constant dense<-4> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -114,9 +114,9 @@ func @const_negative_int_float() -> tensor { // CHECK-LABEL: func @const_int_bf16 func @const_int_bf16() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor - %cst = xla_hlo.constant dense<4> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor + %cst = mhlo.constant dense<4> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -125,9 +125,9 @@ func @const_int_bf16() -> tensor { // CHECK-LABEL: func @const_bf16_int func @const_bf16_int() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42.0> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42.0> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -136,9 +136,9 @@ func @const_bf16_int() -> tensor { // CHECK-LABEL: func @const_int_narrowing func @const_int_narrowing() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -147,9 +147,9 @@ func @const_int_narrowing() -> tensor { // CHECK-LABEL: func @const_int_widening func @const_int_widening() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -158,9 +158,9 @@ func @const_int_widening() -> tensor { // CHECK-LABEL: func @const_negative_int_widening func @const_negative_int_widening() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-42> : tensor - %cst = xla_hlo.constant dense<-42> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor + %cst = mhlo.constant dense<-42> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -169,9 +169,9 @@ func @const_negative_int_widening() -> tensor { // CHECK-LABEL: func @const_float_narrowing func @const_float_narrowing() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor - %cst = xla_hlo.constant dense<4.2> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor + %cst = mhlo.constant dense<4.2> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -180,9 +180,9 @@ func @const_float_narrowing() -> tensor { // CHECK-LABEL: func @const_f32_bf16 func @const_f32_bf16() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+01> : tensor - %cst = xla_hlo.constant dense<42.0> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+01> : tensor + %cst = mhlo.constant dense<42.0> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -191,9 +191,9 @@ func @const_f32_bf16() -> tensor { // CHECK-LABEL: func @const_bf16_f64 func @const_bf16_f64() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor - %cst = xla_hlo.constant dense<4.2> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.187500e+00> : tensor + %cst = mhlo.constant dense<4.2> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -202,9 +202,9 @@ func @const_bf16_f64() -> tensor { // CHECK-LABEL: func @const_bf16_int func @const_bf16_int() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42.0> : tensor - %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42.0> : tensor + %0 = "mhlo.convert"(%cst) : (tensor) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -214,11 +214,11 @@ func @const_bf16_int() -> tensor { // CHECK-LABEL: func @const_high_rank_tensor func @const_high_rank_tensor() -> tensor<2x3xi32> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[ + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[ // CHECK-SAME: [1, 2, 3], [4, 5, 6] // CHECK-SAME: ]> : tensor<2x3xi32> - %cst = xla_hlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> - %0 = "xla_hlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32> + %cst = mhlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> + %0 = "mhlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32> // CHECK-NEXT: return [[CST]] return %0 : tensor<2x3xi32> } diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index b13dd27..0db595c 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -4,7 +4,7 @@ // BOTH-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.exponential"(%tensor_operand) + %tensor_result = "mhlo.exponential"(%tensor_operand) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} @@ -28,11 +28,11 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> { // BOTH-LABEL: func @func_op_long func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> - %2 = xla_hlo.add %arg0, %1 : tensor<4xf32> - %3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> - %4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32> - %5 = xla_hlo.multiply %2, %4 : tensor<4xf32> + %1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32> + %2 = mhlo.add %arg0, %1 : tensor<4xf32> + %3 = mhlo.minimum %arg0, %arg1 : tensor<4xf32> + %4 = mhlo.subtract %arg1, %3 : tensor<4xf32> + %5 = mhlo.multiply %2, %4 : tensor<4xf32> return %5 : tensor<4xf32> } // PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) @@ -65,12 +65,12 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32> %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> - %sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2) + %sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> - %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier) + %tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> @@ -86,7 +86,7 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, // BOTH-LABEL: func @copy func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.copy"(%tensor_operand) + %tensor_result = "mhlo.copy"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -98,7 +98,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @exp func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.exponential"(%tensor_operand) + %tensor_result = "mhlo.exponential"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -110,7 +110,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @log func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.log"(%tensor_operand) + %tensor_result = "mhlo.log"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -125,7 +125,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %tensor_pred = tensor_load %pred : memref<2x2xi1> %tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32> - %tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) + %tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -138,7 +138,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) { %tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32> - %tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs) + %tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs) {comparison_direction = "EQ"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> // BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} @@ -151,7 +151,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x // BOTH-LABEL: func @broadcast func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { %tensor_operand = tensor_load %operand : memref<5xf32> - %tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand) + %tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<10x5xf32> // BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} @@ -170,7 +170,7 @@ func @dyn_broadcast(%operand: memref) { // BOTH-SAME: (%[[OPERAND:.*]]: memref) %tensor_operand = tensor_load %operand : memref %shape = call @external_func() : () -> tensor<3xi64> - %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { + %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> } : (tensor, tensor<3xi64>) -> tensor // BOTH: %[[SHAPE:.*]] = call @external_func() @@ -226,7 +226,7 @@ func @complex(%real: memref<2x2xf32>, %result: memref<2x2xcomplex>) { %tensor_real = tensor_load %real : memref<2x2xf32> %tensor_imag = tensor_load %imag : memref<2x2xf32> - %tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag) + %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex> // BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xcomplex> @@ -238,7 +238,7 @@ func @complex(%real: memref<2x2xf32>, // BOTH-LABEL: func @real func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> - %tensor_result = "xla_hlo.real"(%tensor_operand) + %tensor_result = "mhlo.real"(%tensor_operand) : (tensor<2x2xcomplex>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -250,7 +250,7 @@ func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @imag func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> - %tensor_result = "xla_hlo.imag"(%tensor_operand) + %tensor_result = "mhlo.imag"(%tensor_operand) : (tensor<2x2xcomplex>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -261,7 +261,7 @@ func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @iota func @iota(%result: memref<10xi32>) { - %tensor_result = "xla_hlo.iota"() + %tensor_result = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<10xi32> // BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} tensor_store %tensor_result, %result : memref<10xi32> @@ -273,7 +273,7 @@ func @iota(%result: memref<10xi32>) { // BOTH-LABEL: func @abs func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.abs"(%tensor_operand) + %tensor_result = "mhlo.abs"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -285,7 +285,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @ceil func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.ceil"(%tensor_operand) + %tensor_result = "mhlo.ceil"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -297,7 +297,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @convert func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.convert"(%tensor_operand) + %tensor_result = "mhlo.convert"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) // BOTH-NOT: tensor_store @@ -310,7 +310,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @cos func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.cosine"(%tensor_operand) + %tensor_result = "mhlo.cosine"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -322,7 +322,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @neg func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.negate"(%tensor_operand) + %tensor_result = "mhlo.negate"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -334,7 +334,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @rsqrt func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.rsqrt"(%tensor_operand) + %tensor_result = "mhlo.rsqrt"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -346,7 +346,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @sign func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.sign"(%tensor_operand) + %tensor_result = "mhlo.sign"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -358,7 +358,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @sqrt func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.sqrt"(%tensor_operand) + %tensor_result = "mhlo.sqrt"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -370,7 +370,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // BOTH-LABEL: func @tanh func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> - %tensor_result = "xla_hlo.tanh"(%tensor_operand) + %tensor_result = "mhlo.tanh"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -383,7 +383,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32> - %tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs) + %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> @@ -395,7 +395,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x // Dynamic shape binary element-wise operation. // BOTH-LABEL: func @add_dyn func @add_dyn(%lhs: tensor, %rhs: tensor) { - %result = "xla_hlo.add"(%lhs, %rhs) + %result = "mhlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor // BOTH: %[[C0:.*]] = constant 0 : index // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref @@ -420,7 +420,7 @@ func @add_dyn(%lhs: tensor, %rhs: tensor) { // Dynamic shape unary element-wise operation. // BOTH-LABEL: func @tanh_dyn func @tanh_dyn(%arg0: tensor) { - %result = "xla_hlo.tanh"(%arg0) + %result = "mhlo.tanh"(%arg0) : (tensor) -> tensor // BOTH: %[[C0:.*]] = constant 0 : index // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref @@ -448,7 +448,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // BOTH-NEXT: %[[ALLOC:.*]] = alloc // BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () - %dot = "xla_hlo.dot"(%arg0, %arg0) + %dot = "mhlo.dot"(%arg0, %arg0) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> // PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]]) // ESC: return %[[ALLOC]] @@ -466,7 +466,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor // BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64> // BOTH-SAME: rhs_dilation = dense<[1, 2]> // BOTH-SAME: window_strides = dense<[2, 1]> - %out = "xla_hlo.convolution"(%filter, %input) { + %out = "mhlo.convolution"(%filter, %input) { batch_group_count = 1 : i64, dimension_numbers = { input_batch_dimension = 0 : i64, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index b633f17..320ce06 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -10,7 +10,7 @@ func @float_add(%lhs: tensor<2x2xf32>, // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32 // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = addf %[[ARG0]], %[[ARG1]] // CHECK: linalg.yield %[[RESULT]] - %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xf32>, + %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -22,7 +22,7 @@ func @integer_add(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: addi - %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -34,7 +34,7 @@ func @float_mul(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: mulf - %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>, + %0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -46,7 +46,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: muli - %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -58,7 +58,7 @@ func @float_remainder(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: remf - %0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>, + %0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -70,7 +70,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: remi_signed - %0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -79,7 +79,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>, // CHECK-LABEL: func @float_rsqrt func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { - %tensor_result = "xla_hlo.rsqrt"(%operand) + %tensor_result = "mhlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK: linalg.generic // CHECK: rsqrt @@ -93,7 +93,7 @@ func @float_sub(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: subf - %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, + %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -105,7 +105,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: subi - %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -116,7 +116,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>, func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: absf - %0 = "xla_hlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -126,7 +126,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: exp - %0 = "xla_hlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -136,7 +136,7 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: log - %0 = "xla_hlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -146,7 +146,7 @@ func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: ceilf - %0 = "xla_hlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -156,7 +156,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: negf - %0 = "xla_hlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -166,7 +166,7 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: tanh - %0 = "xla_hlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -177,7 +177,7 @@ func @integer_and(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: and - %0 = "xla_hlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, + %0 = "mhlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -187,7 +187,7 @@ func @integer_and(%lhs: tensor<2x2xi32>, // CHECK-LABEL: func @float_cmp func @float_cmp(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) { - %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} + %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> return %0 : tensor<2x2xi1> } @@ -201,7 +201,7 @@ func @float_cmp(%lhs: tensor<2x2xf32>, // CHECK-LABEL: func @int_cmp func @int_cmp(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi1> { - %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"} + %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"} : (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>) return %0 : tensor<2x2xi1> } @@ -216,7 +216,7 @@ func @int_cmp(%lhs: tensor<2x2xi32>, func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: cos - %0 = "xla_hlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -226,7 +226,7 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: sin - %0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -235,7 +235,7 @@ func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @copy // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { - %0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>) + %0 = "mhlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>) return %0 : tensor<2x4x8xf32> } // CHECK: return [[ARG]] : tensor<2x4x8xf32> @@ -245,7 +245,7 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { // CHECK-LABEL: func @select func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = "xla_hlo.select"(%pred, %lhs, %rhs) + %0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>) return %0 : tensor<2x2xf32> } @@ -260,7 +260,7 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { - %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor) -> tensor<4x2x1xf32> + %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor) -> tensor<4x2x1xf32> return %0: tensor<4x2x1xf32> } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -273,7 +273,7 @@ func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> // CHECK-LABEL: func @broadcast func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { - %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> + %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> return %0: tensor<4x2x1x4x?x16xf32> } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -286,7 +286,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> // CHECK-LABEL: func @broadcast_in_dim func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { - %0 = "xla_hlo.broadcast_in_dim"(%operand) + %0 = "mhlo.broadcast_in_dim"(%operand) {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>} : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> return %0 : tensor<7x10x6x4x5xf32> @@ -302,7 +302,7 @@ func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { // CHECK-LABEL: func @broadcast_in_dim_with_one_to_one func @broadcast_in_dim_with_one_to_one( %operand: tensor<1xf32>) -> tensor<1x5xf32> { - %0 = "xla_hlo.broadcast_in_dim"(%operand) + %0 = "mhlo.broadcast_in_dim"(%operand) {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<1x5xf32> return %0 : tensor<1x5xf32> @@ -317,7 +317,7 @@ func @broadcast_in_dim_with_one_to_one( // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar func @broadcast_scalar(%operand: tensor) -> tensor<7x10x6xf32> { - %0 = "xla_hlo.broadcast_in_dim"(%operand) + %0 = "mhlo.broadcast_in_dim"(%operand) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<7x10x6xf32> return %0 : tensor<7x10x6xf32> @@ -332,7 +332,7 @@ func @broadcast_scalar(%operand: tensor) -> tensor<7x10x6xf32> { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-LABEL: func @transpose func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> return %0 : tensor<3x2x5x9xi32> } @@ -344,7 +344,7 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-LABEL: func @reshape_3D_2D func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32> + %0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32> return %0 : tensor<12x42xi32> } // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] @@ -355,7 +355,7 @@ func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> // CHECK-LABEL: func @reshape_4D_2D func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32> + %0 = "mhlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32> return %0 : tensor<12x42xi32> } // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] @@ -366,7 +366,7 @@ func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> { // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // CHECK-LABEL: func @reshape_2D_4D func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32> + %0 = "mhlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32> return %0 : tensor<12x1x42x1xi32> } // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] @@ -375,7 +375,7 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> { // CHECK-LABEL: func @minf func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = "xla_hlo.minimum"(%lhs, %rhs) + %0 = "mhlo.minimum"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } @@ -389,7 +389,7 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @maxi func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = "xla_hlo.maximum"(%lhs, %rhs) + %0 = "mhlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -404,7 +404,7 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()> // CHECK-LABEL: func @add_scalar func @add_scalar(%lhs: tensor, %rhs: tensor) -> tensor { - %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor return %0 : tensor } // CHECK: linalg.generic @@ -417,7 +417,7 @@ func @add_scalar(%lhs: tensor, %rhs: tensor) -> tensor { func @reshape_collapse_single_dim (%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32> return %0 : tensor<1x784xf32> } // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> @@ -428,7 +428,7 @@ func @reshape_collapse_single_dim // ----- func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> return %0 : tensor<2x4x3xf32> } // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> @@ -440,7 +440,7 @@ func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> { // ----- func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32> return %0 : tensor<2x4x2xf32> } // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)> @@ -451,7 +451,7 @@ func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> { // ----- func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32> return %0 : tensor<1x4x2xf32> } // CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> @@ -462,7 +462,7 @@ func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> { func @reshape_multiple_collapse (%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> + %0 = "mhlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> return %0 : tensor<1x4x5x6xf32> } // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)> @@ -476,7 +476,7 @@ func @reshape_multiple_collapse // CHECK-LABEL: func @convert_i32_to_f32 func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> + %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> return %result : tensor<2x2xf32> } // CHECK: linalg.generic @@ -488,7 +488,7 @@ func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @convert_i16_to_i32 func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32> + %result = "mhlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32> return %result : tensor<2x2xi32> } // CHECK: linalg.generic @@ -500,7 +500,7 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> { // CHECK-LABEL: func @convert_i32_to_i16 func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16> + %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16> return %result : tensor<2x2xi16> } // CHECK: linalg.generic @@ -512,7 +512,7 @@ func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> { // CHECK-LABEL: func @convert_f32_to_f64 func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64> + %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64> return %result : tensor<2x2xf64> } // CHECK: linalg.generic @@ -524,7 +524,7 @@ func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> { // CHECK-LABEL: func @convert_f64_to_f32 func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32> + %result = "mhlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32> return %result : tensor<2x2xf32> } // CHECK: linalg.generic @@ -536,7 +536,7 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { // CHECK-LABEL: func @convert_f32_to_i32 func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { - %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> + %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> return %result : tensor<2x2xi32> } // CHECK: linalg.generic @@ -550,7 +550,7 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reverse func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { - %result = "xla_hlo.reverse"(%input) { + %result = "mhlo.reverse"(%input) { dimensions = dense<1> : tensor<1xi64> } : (tensor<2x3xf32>) -> tensor<2x3xf32> return %result : tensor<2x3xf32> diff --git a/tests/inlining.mlir b/tests/inlining.mlir index 7b1bbf5..f4ed563 100644 --- a/tests/inlining.mlir +++ b/tests/inlining.mlir @@ -1,28 +1,28 @@ // RUN: mlir-hlo-opt %s -inline | FileCheck %s -// Test case: Basic test of inlining into xla_hlo.while. +// Test case: Basic test of inlining into mhlo.while. // CHECK-LABEL: func @caller -// CHECK: "xla_hlo.while"{{.*}}( { +// CHECK: "mhlo.while"{{.*}}( { // CHECK: }, { -// CHECK: "xla_hlo.exponential" +// CHECK: "mhlo.exponential" // CHECK: }) // CHECK-LABEL: func @callee func @caller(%arg0: tensor, %pred: tensor) -> tensor { - %0 = "xla_hlo.while"(%arg0) ( { + %0 = "mhlo.while"(%arg0) ( { ^entry(%unused: tensor): - "xla_hlo.return"(%pred) : (tensor) -> () + "mhlo.return"(%pred) : (tensor) -> () }, { ^entry(%0: tensor): %1 = call @callee(%0) : (tensor) -> (tensor) - "xla_hlo.return"(%1) : (tensor) -> () + "mhlo.return"(%1) : (tensor) -> () } ) : (tensor) -> (tensor) return %0 : tensor } func @callee(%arg0: tensor) -> tensor { - %0 = "xla_hlo.exponential"(%arg0) : (tensor) -> tensor + %0 = "mhlo.exponential"(%arg0) : (tensor) -> tensor return %0 : tensor } diff --git a/tests/legalize-control-flow.mlir b/tests/legalize-control-flow.mlir index 4096b06..1d5faea 100644 --- a/tests/legalize-control-flow.mlir +++ b/tests/legalize-control-flow.mlir @@ -4,21 +4,21 @@ func @while(%arg0: tensor) -> tensor { //CHECK: br ^bb1(%arg0 : tensor) //CHECK: ^bb1([[VAL0:%.+]]: tensor): - //CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]]) + //CHECK: [[VAL1:%.+]] = "mhlo.compare"([[VAL0]], [[VAL0]]) //CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor //CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor), ^bb3([[VAL0]] : tensor) //CHECK: ^bb2([[VAL3:%.+]]: tensor): - //CHECK: [[VAL4:%.+]] = xla_hlo.add [[VAL3]], [[VAL3]] + //CHECK: [[VAL4:%.+]] = mhlo.add [[VAL3]], [[VAL3]] //CHECK: br ^bb1([[VAL4]] : tensor) //CHECK: ^bb3([[VAL5:%.+]]: tensor): - %0 = "xla_hlo.while"(%arg0) ( { + %0 = "mhlo.while"(%arg0) ( { ^bb0(%arg1: tensor): - %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg1: tensor): - %1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) : (tensor) -> tensor // CHECK-NEXT: return [[VAL5]] @@ -30,27 +30,27 @@ func @conditional(%arg0: tensor) -> tensor { // CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor %cst = constant dense<1.000000e+01> : tensor - // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - %0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + %0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor // CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor), ^bb2(%arg0 : tensor) - %1 = "xla_hlo.if"(%0, %arg0, %arg0) ( { + %1 = "mhlo.if"(%0, %arg0, %arg0) ( { ^bb0(%arg1: tensor): // CHECK: ^bb1([[VAL2:%.+]]: tensor): - // CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor) -> tensor + // CHECK: [[VAL3:%.+]] = "mhlo.log"([[VAL2]]) : (tensor) -> tensor // CHECK: br ^bb3([[VAL3]] : tensor) - %2 = "xla_hlo.log"(%arg1) : (tensor) -> tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = "mhlo.log"(%arg1) : (tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () }, { ^bb0(%arg1: tensor): // CHECK: ^bb2([[VAL4:%.+]]: tensor): - // CHECK: [[VAL5:%.+]] = "xla_hlo.exponential"([[VAL4]]) : (tensor) -> tensor + // CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]]) : (tensor) -> tensor // CHECK: br ^bb3([[VAL5]] : tensor) - %2 = "xla_hlo.exponential"(%arg1) : (tensor) -> tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = "mhlo.exponential"(%arg1) : (tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () }) : (tensor, tensor, tensor) -> tensor // CHECK: ^bb3([[VAL6:%.+]]: tensor): @@ -62,27 +62,27 @@ func @conditional(%arg0: tensor) -> tensor { func @while_with_multiple_blocks_in_body(%arg0: tensor) -> tensor { // CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor) // CHECK: ^[[COND_ENTRY]](%0: tensor): - // CHECK: %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK: %2 = extract_element %1[] : tensor // CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor), ^[[EXIT:.+]](%0 : tensor) // CHECK: ^[[BODY_ENTRY]](%3: tensor): // CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor) // CHECK: ^[[BODY_SUCC]](%4: tensor): - // CHECK: %5 = xla_hlo.add %4, %4 : tensor + // CHECK: %5 = mhlo.add %4, %4 : tensor // CHECK: br ^[[COND_ENTRY]](%5 : tensor) // CHECK: ^[[EXIT]](%6: tensor): // CHECK: return %6 : tensor // CHECK: } - %0 = "xla_hlo.while"(%arg0) ( { + %0 = "mhlo.while"(%arg0) ( { ^cond_entry(%arg1: tensor): - %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^body_entry(%arg1: tensor): br ^body_succ(%arg1: tensor) ^body_succ(%0: tensor): - %1 = xla_hlo.add %0, %0 : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %0, %0 : tensor + "mhlo.return"(%1) : (tensor) -> () }) : (tensor) -> tensor return %0 : tensor @@ -94,7 +94,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor) -> tensor { // CHECK: ^[[COND_ENTRY]](%0: tensor): // CHECK: br ^[[COND_SUCC:.+]](%0 : tensor) // CHECK: ^[[COND_SUCC]](%1: tensor): - // CHECK: %2 = "xla_hlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK: %3 = extract_element %2[] : tensor // CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor), ^[[EXIT:.+]](%0 : tensor) // CHECK: ^[[BODY_ENTRY]](%4: tensor): @@ -102,15 +102,15 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor) -> tensor { // CHECK: ^[[EXIT]](%5: tensor): // CHECK: return %5 : tensor // CHECK: } - %0 = "xla_hlo.while"(%arg0) ( { + %0 = "mhlo.while"(%arg0) ( { ^cond_entry(%arg1: tensor): br ^cond_succ(%arg1: tensor) ^cond_succ(%0: tensor): - %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^body_entry(%arg1: tensor): - "xla_hlo.return"(%arg1) : (tensor) -> () + "mhlo.return"(%arg1) : (tensor) -> () }) : (tensor) -> tensor return %0 : tensor @@ -123,24 +123,24 @@ func @conditional_with_multiple_blocks(%arg0: tensor, %arg1: tensor, % // CHECK: ^[[THEN_ENTRY]](%1: tensor): // CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor) // CHECK: ^[[THEN_SUCC]](%2: tensor): - // CHECK: %3 = "xla_hlo.log"(%2) : (tensor) -> tensor + // CHECK: %3 = "mhlo.log"(%2) : (tensor) -> tensor // CHECK: br ^[[EXIT:.+]](%3 : tensor) // CHECK: ^[[ELSE_ENTRY]](%4: tensor): - // CHECK: %5 = "xla_hlo.exponential"(%4) : (tensor) -> tensor + // CHECK: %5 = "mhlo.exponential"(%4) : (tensor) -> tensor // CHECK: br ^[[EXIT]](%5 : tensor) // CHECK: ^[[EXIT]](%6: tensor): // CHECK: return %6 : tensor // CHECK: } - %1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( { + %1 = "mhlo.if"(%pred, %arg0, %arg1) ( { ^then_entry(%arg2: tensor): br ^then_succ(%arg2: tensor) ^then_succ(%0: tensor): - %2 = "xla_hlo.log"(%0) : (tensor) -> tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = "mhlo.log"(%0) : (tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () }, { ^else_entry(%arg2: tensor): - %2 = "xla_hlo.exponential"(%arg2) : (tensor) -> tensor - "xla_hlo.return"(%2) : (tensor) -> () + %2 = "mhlo.exponential"(%arg2) : (tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () }) : (tensor, tensor, tensor) -> tensor return %1 : tensor } diff --git a/tests/legalize-to-std.mlir b/tests/legalize-to-std.mlir index c4153b2..774f926 100644 --- a/tests/legalize-to-std.mlir +++ b/tests/legalize-to-std.mlir @@ -3,19 +3,19 @@ // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %0 = addf %arg0, %arg1 : tensor<4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32> - %1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32> - %2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32> - %3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %4 = remf %3, %arg1 : tensor<4xf32> - %4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %4 : tensor<4xf32> return %4 : tensor<4xf32> @@ -24,19 +24,19 @@ func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf // CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK-NEXT: %0 = addi %arg0, %arg1 : tensor<4xi32> - %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32> - %1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32> - %2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32> - %3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32> - %4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: return %4 : tensor<4xi32> return %4 : tensor<4xi32> @@ -45,17 +45,17 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32 // CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32> - %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32> - %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32> - %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32> - %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32> - %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32> - %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> } @@ -63,28 +63,28 @@ func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi // CHECK-LABEL: func @compare_float func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { // CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32> - %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32> - %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32> - %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32> - %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32> - %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32> - %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> } // CHECK-LABEL: func @int_constant func @int_constant() -> (tensor, tensor<2x3xi32>, tensor<2x3xi32>) { // CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor - %0 = "xla_hlo.constant"() {value = dense<0> : tensor} : () -> (tensor) + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> (tensor) // CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32> - %1 = "xla_hlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) + %1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) // CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32> - %2 = "xla_hlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) + %2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) // CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor, tensor<2x3xi32>, tensor<2x3xi32> return %0, %1, %2: tensor, tensor<2x3xi32>, tensor<2x3xi32> } @@ -92,11 +92,11 @@ func @int_constant() -> (tensor, tensor<2x3xi32>, tensor<2x3xi32>) { // CHECK-LABEL: func @float_constant func @float_constant() -> (tensor, tensor<2x3xf32>, tensor<2x3xf32>) { // CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor - %0 = "xla_hlo.constant"() {value = dense<0.0> : tensor} : () -> (tensor) + %0 = "mhlo.constant"() {value = dense<0.0> : tensor} : () -> (tensor) // CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32> - %1 = "xla_hlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) + %1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) // CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32> - %2 = "xla_hlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) + %2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) // CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor, tensor<2x3xf32>, tensor<2x3xf32> return %0, %1, %2: tensor, tensor<2x3xf32>, tensor<2x3xf32> } @@ -105,7 +105,7 @@ func @float_constant() -> (tensor, tensor<2x3xf32>, tensor<2x3xf32>) { // CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> { func @iota.const.1() -> tensor<4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> // CHECK-NEXT: return %[[CST]] : tensor<4xi32> return %0 : tensor<4xi32> } @@ -113,7 +113,7 @@ func @iota.const.1() -> tensor<4xi32> { // CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> { func @iota.const.2() -> tensor<2x4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> return %0 : tensor<2x4xi32> } @@ -121,7 +121,7 @@ func @iota.const.2() -> tensor<2x4xi32> { // CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> { func @iota.const.3() -> tensor<2x4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> return %0 : tensor<2x4xi32> } @@ -129,7 +129,7 @@ func @iota.const.3() -> tensor<2x4xi32> { // CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> { func @iota.const.4() -> tensor<2x3x4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> return %0 : tensor<2x3x4xi32> } @@ -137,7 +137,7 @@ func @iota.const.4() -> tensor<2x3x4xi32> { // CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> { func @iota.const.5() -> tensor<2x3x4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> return %0 : tensor<2x3x4xi32> } @@ -145,7 +145,7 @@ func @iota.const.5() -> tensor<2x3x4xi32> { // CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> { func @iota.const.6() -> tensor<2x3x4xi32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32> - %0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> return %0 : tensor<2x3x4xi32> } @@ -153,7 +153,7 @@ func @iota.const.6() -> tensor<2x3x4xi32> { // CHECK-LABEL: func @iota.const.f32 func @iota.const.f32() -> tensor<4xf32> { // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> // CHECK-NEXT: return %[[CST]] : tensor<4xf32> return %0 : tensor<4xf32> } @@ -161,7 +161,7 @@ func @iota.const.f32() -> tensor<4xf32> { // CHECK-LABEL: func @iota.const.f64 func @iota.const.f64() -> tensor<4xf64> { // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64> // CHECK-NEXT: return %[[CST]] : tensor<4xf64> return %0 : tensor<4xf64> } @@ -169,7 +169,7 @@ func @iota.const.f64() -> tensor<4xf64> { // CHECK-LABEL: func @iota.const.bf16 func @iota.const.bf16() -> tensor<4xbf16> { // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16> - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16> // CHECK-NEXT: return %[[CST]] : tensor<4xbf16> return %0 : tensor<4xbf16> } @@ -178,8 +178,8 @@ func @iota.const.bf16() -> tensor<4xbf16> { func @iota.const.complex.f32() -> tensor<4xcomplex> { // CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32> // CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32> - // CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]]) - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> + // CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]]) + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex> return %0 : tensor<4xcomplex> } @@ -188,8 +188,8 @@ func @iota.const.complex.f32() -> tensor<4xcomplex> { func @iota.const.complex.f64() -> tensor<4xcomplex> { // CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64> // CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64> - // CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]]) - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> + // CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]]) + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex> return %0 : tensor<4xcomplex> } diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir index 11cecde..e793e2a 100644 --- a/tests/lhlo_ops.mlir +++ b/tests/lhlo_ops.mlir @@ -396,9 +396,9 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m "xla_lhlo.fusion"() ( { %0 = tensor_load %input1 : memref<10xf32> %1 = tensor_load %input2 : memref<10xf32> - %2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> %3 = tensor_load %input3 : memref<10xf32> - %4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> tensor_store %4, %out : memref<10xf32> "xla_lhlo.terminator"() : () -> () } ) : () -> () @@ -803,15 +803,15 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () { "xla_lhlo.all_reduce"(%arg0, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): - %max = xla_hlo.maximum %lhs, %rhs : tensor - "xla_hlo.return"(%max) : (tensor) -> () + %max = mhlo.maximum %lhs, %rhs : tensor + "mhlo.return"(%max) : (tensor) -> () }) { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> () "xla_lhlo.all_reduce"(%arg0, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): - %max = xla_hlo.maximum %lhs, %rhs : tensor - "xla_hlo.return"(%max) : (tensor) -> () + %max = mhlo.maximum %lhs, %rhs : tensor + "mhlo.return"(%max) : (tensor) -> () }) { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, @@ -958,8 +958,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32 %updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () { "xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): // no predecessors - %add = xla_hlo.add %lhs, %rhs : tensor - "xla_hlo.return"(%add) : (tensor) -> () + %add = mhlo.add %lhs, %rhs : tensor + "mhlo.return"(%add) : (tensor) -> () }) { scatter_dimension_numbers = { update_window_dims = dense<[1]> : tensor<1xi64>, @@ -979,8 +979,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32 func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () { "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ ^bb0(%a: tensor, %b: tensor): - %c = xla_hlo.add %a, %b : tensor - "xla_hlo.return"(%c) : (tensor) -> () + %c = mhlo.add %a, %b : tensor + "mhlo.return"(%c) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> () return } @@ -991,8 +991,8 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref // expected-error@+1{{requires the same shape for all operands}} "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ ^bb0(%a: tensor, %b: tensor): - %c = xla_hlo.add %a, %b : tensor - "xla_hlo.return"(%c) : (tensor) -> () + %c = mhlo.add %a, %b : tensor + "mhlo.return"(%c) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> () return } @@ -1012,8 +1012,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): - %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> () return } @@ -1025,8 +1025,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): - %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> () return } @@ -1038,8 +1038,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): - %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> () return } diff --git a/tests/lower-complex.mlir b/tests/lower-complex.mlir index 696e225..4db4d80 100644 --- a/tests/lower-complex.mlir +++ b/tests/lower-complex.mlir @@ -2,14 +2,14 @@ // CHECK-LABEL: @add func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3 - %4 = "xla_hlo.add"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + // CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3 + %4 = "mhlo.add"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "mhlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) // CHECK: return [[VAL0]], [[VAL1]] return %5, %6 : tensor<2xf32>, tensor<2xf32> @@ -17,14 +17,14 @@ func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % // CHECK-LABEL: @add_unranked func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3 - %4 = "xla_hlo.add"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + // CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3 + %4 = "mhlo.add"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "mhlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) // CHECK: return [[VAL0]], [[VAL1]] return %5, %6 : tensor<*xf32>, tensor<*xf32> @@ -32,14 +32,14 @@ func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< // CHECK-LABEL: @sub func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3 - %4 = "xla_hlo.subtract"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + // CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3 + %4 = "mhlo.subtract"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "mhlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) // CHECK: return [[VAL0]], [[VAL1]] return %5, %6 : tensor<2xf32>, tensor<2xf32> @@ -47,14 +47,14 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % // CHECK-LABEL: @sub_unranked func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3 - %4 = "xla_hlo.subtract"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + // CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3 + %4 = "mhlo.subtract"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "mhlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) // CHECK: return [[VAL0]], [[VAL1]] return %5, %6 : tensor<*xf32>, tensor<*xf32> @@ -62,18 +62,18 @@ func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< // CHECK-LABEL: @mul func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3 - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] - %4 = "xla_hlo.multiply"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + // CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3 + // CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3 + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]] + %4 = "mhlo.multiply"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "mhlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) // CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32> return %5, %6 : tensor<2xf32>, tensor<2xf32> @@ -81,18 +81,18 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % // CHECK-LABEL: @mul_unranked func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3 - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] - %4 = "xla_hlo.multiply"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + // CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3 + // CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3 + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]] + %4 = "mhlo.multiply"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + %5 = "mhlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "mhlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) // CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32> return %5, %6 : tensor<*xf32>, tensor<*xf32> @@ -100,36 +100,36 @@ func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< // CHECK-LABEL: @div func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) + // CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3) // Compute the numerator's real component: // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]] + // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]] + // CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]] // Compute the real valued denominator as rhs * con(rhs): // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]] + // CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]] // Compute the numerator's imaginary component: // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag - // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2 - // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]] - // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] + // CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]] + // CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]] // Divide the numerator by the real valued denominator. - // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]] - // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]] - %4 = "xla_hlo.divide"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + // CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]] + // CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]] + %4 = "mhlo.divide"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %5 = "mhlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "mhlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) // CHECK: return [[VAL10]], [[VAL11]] return %5, %6 : tensor<2xf32>, tensor<2xf32> @@ -139,36 +139,36 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % // CHECK-LABEL: @div_unranked func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) + // CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3) // Compute the numerator's real component: // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]] - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]] + // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]] + // CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]] // Compute the real valued denominator as rhs * con(rhs): // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]] + // CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]] // Compute the numerator's imaginary component: // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag - // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2 - // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]] - // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] + // CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]] + // CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]] // Divide the numerator by the real valued denominator. - // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]] - // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]] - %4 = "xla_hlo.divide"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + // CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]] + // CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]] + %4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %5 = "mhlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "mhlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) // CHECK: return [[VAL10]], [[VAL11]] return %5, %6 : tensor<*xf32>, tensor<*xf32> @@ -176,14 +176,14 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< // CHECK-LABEL: @abs func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { - %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg0 - // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg1 - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]]) - %1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) - %2 = "xla_hlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) + // CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg0 + // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1 + // CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]]) + %1 = "mhlo.abs"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) + %2 = "mhlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) // CHECK: return [[VAL3]] return %2 : tensor<2xf32> @@ -191,16 +191,16 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { // CHECK-LABEL: @exp func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0) - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1) - // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1) - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]] - %1 = "xla_hlo.exponential"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) - %2 = "xla_hlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) - %3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) + // CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0) + // CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1) + // CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1) + // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]] + %1 = "mhlo.exponential"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) + %2 = "mhlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %3 = "mhlo.imag"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) // CHECK: return [[VAL3]], [[VAL4]] return %2, %3 : tensor<2xf32>, tensor<2xf32> @@ -208,16 +208,16 @@ func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tenso // CHECK-LABEL: @exp_unranked func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0) - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1) - // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1) - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]] - %1 = "xla_hlo.exponential"(%0) : (tensor<*xcomplex>) -> (tensor<*xcomplex>) - %2 = "xla_hlo.real"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) - %3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) + // CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0) + // CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1) + // CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1) + // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]] + %1 = "mhlo.exponential"(%0) : (tensor<*xcomplex>) -> (tensor<*xcomplex>) + %2 = "mhlo.real"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %3 = "mhlo.imag"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) // CHECK: return [[VAL3]], [[VAL4]] return %2, %3 : tensor<*xf32>, tensor<*xf32> diff --git a/tests/lower-general-dot.mlir b/tests/lower-general-dot.mlir index b54a0aa..3ee23da 100644 --- a/tests/lower-general-dot.mlir +++ b/tests/lower-general-dot.mlir @@ -2,10 +2,10 @@ // CHECK-LABEL: @testDebatch1 func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> { - // CHECK-DAG: [[R0:%.+]] = "xla_hlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32> - // CHECK-DAG: [[R1:%.+]] = "xla_hlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - // CHECK: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32> - %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32> + // CHECK-DAG: [[R0:%.+]] = "mhlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32> + // CHECK-DAG: [[R1:%.+]] = "mhlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + // CHECK: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32> + %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32> return %0 : tensor<1x1x3xf32> } @@ -14,13 +14,13 @@ func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1 // CHECK-LABEL: @testDebatch2 func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> { - // CHECK-DAG: [[R0:%.+]] = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> - // CHECK-DAG: [[R1:%.+]] = "xla_hlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32> - // CHECK-DAG: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32> - // CHECK-DAG: [[R3:%.+]] = "xla_hlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32> - // CHECK: [[R4:%.+]] = "xla_hlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + // CHECK-DAG: [[R0:%.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> + // CHECK-DAG: [[R1:%.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32> + // CHECK-DAG: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32> + // CHECK-DAG: [[R3:%.+]] = "mhlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32> + // CHECK: [[R4:%.+]] = "mhlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32> - %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32> + %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32> return %0 : tensor<3x1x1xf32> } @@ -28,8 +28,8 @@ func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3 // CHECK-LABEL: @testBatchPassthrough func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> { - // CHECK-NEXT: "xla_hlo.dot_general"(%arg0, %arg1) - %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32> + // CHECK-NEXT: "mhlo.dot_general"(%arg0, %arg1) + %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32> return %0 : tensor<3x2x1xf32> } diff --git a/tests/materialize-broadcasts.mlir b/tests/materialize-broadcasts.mlir index bfe1fe3..4fd8b3d 100644 --- a/tests/materialize-broadcasts.mlir +++ b/tests/materialize-broadcasts.mlir @@ -3,9 +3,9 @@ // CHECK-LABEL: @clampBroadcast // CHECK-SAME: (%[[MIN:.+]]: tensor, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor) func @clampBroadcast(%min: tensor, %value: tensor<4xf32>, %max: tensor) -> tensor<4xf32> { - // CHECK-DAG: %[[MIN_BC:.+]] = "xla_hlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> - // CHECK-DAG: %[[MAX_BC:.+]] = "xla_hlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> - // CHECK: "xla_hlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK-DAG: %[[MIN_BC:.+]] = "mhlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> + // CHECK-DAG: %[[MAX_BC:.+]] = "mhlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> + // CHECK: "mhlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "mhlo.clamp"(%min, %value, %max) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> return %0 : tensor<4xf32> } diff --git a/tests/ops.mlir b/tests/ops.mlir index 727e747..b46827b 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -3,19 +3,19 @@ // Tests for types, ops with custom constraints, verifiers, printer or parser // methods. -// CHECK-LABEL: func @token_type() -> !xla_hlo.token -func @token_type() -> !xla_hlo.token +// CHECK-LABEL: func @token_type() -> !mhlo.token +func @token_type() -> !mhlo.token // ----- -// expected-error@+1 {{unknown xla_hlo type: foobar}} -func @invalid_type() -> !xla_hlo.foobar +// expected-error@+1 {{unknown mhlo type: foobar}} +func @invalid_type() -> !mhlo.foobar // ----- // CHECK-LABEL: func @alltoall func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { - %0 = "xla_hlo.all_to_all"(%data) { + %0 = "mhlo.all_to_all"(%data) { split_dimension = 1 : i64, concat_dimension = 0 : i64, split_count = 4 : i64, @@ -28,7 +28,7 @@ func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { // CHECK-LABEL: func @alltoall_unranked_input func @alltoall_unranked_input(%data: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.all_to_all"(%data) { + %0 = "mhlo.all_to_all"(%data) { split_dimension = 1 : i64, concat_dimension = 0 : i64, split_count = 5 : i64, @@ -41,7 +41,7 @@ func @alltoall_unranked_input(%data: tensor<*xf32>) -> tensor<*xf32> { func @alltoall_invalid_split_dim_size(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { // expected-error@+1 {{split dimension has size 16, expected to be a multiple of split_count 5}} - %0 = "xla_hlo.all_to_all"(%data) { + %0 = "mhlo.all_to_all"(%data) { split_dimension = 1 : i64, concat_dimension = 0 : i64, split_count = 5 : i64, @@ -54,7 +54,7 @@ func @alltoall_invalid_split_dim_size(%data: tensor<4x16xf32>) -> tensor<16x4xf3 // CHECK-LABEL: func @broadcast func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -62,7 +62,7 @@ func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_sizes has rank 2 instead of rank 1}} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -70,7 +70,7 @@ func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_bad_result_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{result rank (3) does not match operand rank (1) plus size of broadcast_sizes (1)}} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -78,7 +78,7 @@ func @broadcast_bad_result_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_bad_first_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{result has shape [1, 3] instead of [2, 3]}} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x3xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x3xi32> return %0 : tensor<1x3xi32> } @@ -86,7 +86,7 @@ func @broadcast_bad_first_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{result has shape [2, 1] instead of [2, 3]}} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x1xi32> + %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x1xi32> return %0 : tensor<2x1xi32> } @@ -94,7 +94,7 @@ func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2 // CHECK-LABEL: func @broadcast_in_dim func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> return %0 : tensor<1x2x2xi32> } @@ -102,7 +102,7 @@ func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { // CHECK-LABEL: func @broadcast_in_dim_zero_rank func @broadcast_in_dim_zero_rank(%arg0: tensor) -> tensor<1x2x3xi32> { - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -110,7 +110,7 @@ func @broadcast_in_dim_zero_rank(%arg0: tensor) -> tensor<1x2x3xi32> { // CHECK-LABEL: func @dynamic_broadcast_in_dim func @dynamic_broadcast_in_dim(%arg0: tensor, %shape: tensor<3xi64>) -> tensor { - %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xi64>) -> tensor + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xi64>) -> tensor return %0 : tensor } @@ -118,7 +118,7 @@ func @dynamic_broadcast_in_dim(%arg0: tensor, %shape: tensor<3xi64>) -> func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -126,7 +126,7 @@ func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x func @broadcast_in_dim_bad_dimension_size(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions size (1) does not match operand rank (2)}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -134,7 +134,7 @@ func @broadcast_in_dim_bad_dimension_size(%arg0: tensor<1x2xi32>) -> tensor<1x2x func @broadcast_in_dim_bad_rank_decrease(%arg0: tensor<1x2x3xi32>) -> tensor<3xi32> { // expected-error@+1 {{result rank (1) is less than operand rank (3)}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x3xi32>) -> tensor<3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x3xi32>) -> tensor<3xi32> return %0 : tensor<3xi32> } @@ -142,7 +142,7 @@ func @broadcast_in_dim_bad_rank_decrease(%arg0: tensor<1x2x3xi32>) -> tensor<3xi func @broadcast_in_dim_dimension_values_too_large(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions contains invalid value 9 for result result with rank 3}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[9, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[9, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -150,7 +150,7 @@ func @broadcast_in_dim_dimension_values_too_large(%arg0: tensor<1x2xi32>) -> ten func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{size of operand dimension 0 (3) is not equal to 1 or size of result dimension 1 (2)}} - %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -158,18 +158,18 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x func @case_mismatch_num_args(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}} - %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): - %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor, %arg1: tensor): - %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.copy"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () } ) : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor @@ -179,18 +179,18 @@ func @case_mismatch_num_args(%index: tensor, %operand_1: tensor, %oper func @case_mismatch_num_results(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { // expected-error@+1 {{branch 1 returned values do not match op result types}} - %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): - %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1, %arg0) : (tensor, tensor) -> () + %1 = "mhlo.copy"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1, %arg0) : (tensor, tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () } ) : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor @@ -200,18 +200,18 @@ func @case_mismatch_num_results(%index: tensor, %operand_1: tensor, %o func @case_mismatch_arg_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { // expected-error@+1 {{expects operand 2 to be of type 'tensor', but found 'tensor'}} - %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): - %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = xla_hlo.constant dense<2.0> : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.constant dense<2.0> : tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () } ) : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor @@ -221,18 +221,18 @@ func @case_mismatch_arg_type(%index: tensor, %operand_1: tensor, %oper func @case_mismatch_return_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { // expected-error@+1 {{branch 1 returned values do not match op result types}} - %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): - %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = xla_hlo.constant dense<2> : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.constant dense<2> : tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () } ) : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor @@ -242,7 +242,7 @@ func @case_mismatch_return_type(%index: tensor, %operand_1: tensor, %o func @case_empty_region(%index: tensor, %operand_1: tensor) -> () { // expected-error@+1 {{cannot have empty regions}} - "xla_hlo.case"(%index, %operand_1) ( {} ) : (tensor, tensor) -> tensor + "mhlo.case"(%index, %operand_1) ( {} ) : (tensor, tensor) -> tensor return } @@ -250,7 +250,7 @@ func @case_empty_region(%index: tensor, %operand_1: tensor) -> () { // CHECK-LABEL: func @comp_eq func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> return %0 : tensor<3xi1> } @@ -258,7 +258,7 @@ func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { func @comp_bad_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { // expected-error@+1 {{'comparison_direction' failed to satisfy constraint}} - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "FOOBAR"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "FOOBAR"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> return %0 : tensor<3xi1> } @@ -266,7 +266,7 @@ func @comp_bad_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3 func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{duplicate sources not allowed}} - %0 = "xla_hlo.collective_permute"(%arg0) { + %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1], [0, 2], [2, 3]]> : tensor<3x2xi64> } : (tensor<128x32xf32>) -> tensor<128x32xf32> return %0 : tensor<128x32xf32> @@ -276,7 +276,7 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor< func @collective_permute_duplicate_targets(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{duplicate targets not allowed}} - %0 = "xla_hlo.collective_permute"(%arg0) { + %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 1]]> : tensor<3x2xi64> } : (tensor<128x32xf32>) -> tensor<128x32xf32> return %0 : tensor<128x32xf32> @@ -286,7 +286,7 @@ func @collective_permute_duplicate_targets(%arg0: tensor<128x32xf32>) -> tensor< func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{expect source_target_pairs attribute to be of rank 2, but got rank 1}} - %0 = "xla_hlo.collective_permute"(%arg0) { + %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[0, 1]> : tensor<2xi64> } : (tensor<128x32xf32>) -> tensor<128x32xf32> return %0 : tensor<128x32xf32> @@ -296,7 +296,7 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor< func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{expect source_target_pairs attribute of shape (N, 2), but got (2, 3)}} - %0 = "xla_hlo.collective_permute"(%arg0) { + %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi64> } : (tensor<128x32xf32>) -> tensor<128x32xf32> return %0 : tensor<128x32xf32> @@ -306,15 +306,15 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor< // CHECK-LABEL: @concat_1D func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> return %0 : tensor<3xi32> } // ----- func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> { - // expected-error@+1 {{'xla_hlo.concatenate' op requires the same element type for all operands and results}} - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32> + // expected-error@+1 {{'mhlo.concatenate' op requires the same element type for all operands and results}} + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32> return %0 : tensor<3xi32> } @@ -322,23 +322,23 @@ func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tenso // CHECK-LABEL: @concat_1D_unranked func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> return %0 : tensor<*xi32> } // ----- func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> { - // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}} - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32> + // expected-error@+1 {{'mhlo.concatenate' op inferred type incompatible with return type of operation}} + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32> return %0 : tensor<3xi32> } // ----- func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> { - // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}} - %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32> + // expected-error@+1 {{'mhlo.concatenate' op inferred type incompatible with return type of operation}} + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -346,7 +346,7 @@ func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi // CHECK-LABEL: func @clamp func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> { - %0 = "xla_hlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %0 = "mhlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -354,15 +354,15 @@ func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @clamp_scalar func @clamp_scalar(%arg0: tensor<1xi32>, %arg1: tensor) -> tensor<1xi32> { - %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg1) : (tensor, tensor<1xi32>, tensor) -> tensor<1xi32> + %0 = "mhlo.clamp"(%arg1, %arg0, %arg1) : (tensor, tensor<1xi32>, tensor) -> tensor<1xi32> return %0: tensor<1xi32> } // ----- func @clamp_invalid_clamp_element_type(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> tensor<1xi32> { - // expected-error@+1 {{'xla_hlo.clamp' op requires the same element type for all operands and results}} - %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg0) : (tensor<1xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // expected-error@+1 {{'mhlo.clamp' op requires the same element type for all operands and results}} + %0 = "mhlo.clamp"(%arg1, %arg0, %arg0) : (tensor<1xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -370,7 +370,7 @@ func @clamp_invalid_clamp_element_type(%arg0: tensor<1xi32>, %arg1: tensor<1xf32 func @clamp_invalid_clamp_shape(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<1xi32> { // expected-error@+1 {{min shape [2] is not scalar and does not match operand shape [1]}} - %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg0) : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %0 = "mhlo.clamp"(%arg1, %arg0, %arg0) : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -378,7 +378,7 @@ func @clamp_invalid_clamp_shape(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> t // CHECK-LABEL: func @dot_vector func @dot_vector(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) -> tensor { - %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor return %0: tensor } @@ -386,7 +386,7 @@ func @dot_vector(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) -> tensor // CHECK-LABEL: func @dot_matrix func @dot_matrix(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0: tensor<2x2xi32> } @@ -394,7 +394,7 @@ func @dot_matrix(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi // CHECK-LABEL: func @dot_precision_config func @dot_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0: tensor<2x2xi32> } @@ -402,23 +402,23 @@ func @dot_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> te func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { // expected-error@+1 {{'precision_config' failed to satisfy constraint}} - %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["FOO", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["FOO", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0: tensor<2x2xi32> } // ----- -func @infeed_invalid_number_of_results(%token: !xla_hlo.token) -> tuple>, !xla_hlo.token, tensor> { +func @infeed_invalid_number_of_results(%token: !mhlo.token) -> tuple>, !mhlo.token, tensor> { // expected-error@+1 {{result is expected to be a tuple of size 2, but got 3}} - %0 = "xla_hlo.infeed"(%token) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple>, !xla_hlo.token, tensor> - return %0 : tuple>, !xla_hlo.token, tensor> + %0 = "mhlo.infeed"(%token) {infeed_config = "foobar"} : (!mhlo.token) -> tuple>, !mhlo.token, tensor> + return %0 : tuple>, !mhlo.token, tensor> } // ----- -func @infeed_non_token_second_result(%token: !xla_hlo.token) -> tuple>, tensor> { +func @infeed_non_token_second_result(%token: !mhlo.token) -> tuple>, tensor> { // expected-error@+1 {{second element of result tuple is expected to be of token type, but got 'tensor'}} - %0 = "xla_hlo.infeed"(%token) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple>, tensor> + %0 = "mhlo.infeed"(%token) {infeed_config = "foobar"} : (!mhlo.token) -> tuple>, tensor> return %0 : tuple>, tensor> } @@ -426,7 +426,7 @@ func @infeed_non_token_second_result(%token: !xla_hlo.token) -> tuple tensor { // expected-error@+1 {{does not support scalars}} - %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor return %0 : tensor } @@ -434,7 +434,7 @@ func @iota_scalar() -> tensor { func @iota_invalid_iota_dimension() -> tensor<4xi32> { // expected-error@+1 {{iota dimension cannot go beyond the output rank or be negative}} - %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4xi32> + %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -442,10 +442,10 @@ func @iota_invalid_iota_dimension() -> tensor<4xi32> { func @map_mismatched_args(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // expected-error@+1 {{expects number of operands to match the arity of map computation, but got: 2 and 1}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg: tensor): - %1 = xla_hlo.add %arg, %arg {name = "add"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg, %arg {name = "add"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -454,10 +454,10 @@ func @map_mismatched_args(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor< func @map_non_scalar_computation_operand(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{computation arguments must be 0-rank tensor, but got: arg #1 of type 'tensor<5xf32>'}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor<5xf32>): - %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -466,10 +466,10 @@ func @map_non_scalar_computation_operand(%arg0: tensor<4x5xf32>, %arg1: tensor<4 func @map_mismatch_operand_and_computation_args(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{element type of operands and computation arguments must match, but got: 'f32' and 'i32'}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -478,10 +478,10 @@ func @map_mismatch_operand_and_computation_args(%arg0: tensor<4x5xf32>, %arg1: t func @map_invalid_number_of_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{computation must return single output, but got: 0}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor - "xla_hlo.return"() : () -> () + %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + "mhlo.return"() : () -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -490,10 +490,10 @@ func @map_invalid_number_of_computation_output(%arg0: tensor<4x5xf32>, %arg1: te func @main_non_scalar_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{computation must return 0-rank tensor, but got: 'tensor<5xf32>'}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor<5xf32> - "xla_hlo.return"(%1) : (tensor<5xf32>) -> () + %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor<5xf32> + "mhlo.return"(%1) : (tensor<5xf32>) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -502,10 +502,10 @@ func @main_non_scalar_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4 func @mismatch_computation_output_type(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{element type of result and computation output must match, but got: 'f32' and 'i32'}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.constant {value = dense<2> : tensor} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.constant {value = dense<2> : tensor} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -514,10 +514,10 @@ func @mismatch_computation_output_type(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5 func @map_invalid_dimension_numbers(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{requires monotonically increasing dimension numbers, but got: dense<[1, 0]> : tensor<2xi64>}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -526,10 +526,10 @@ func @map_invalid_dimension_numbers(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf3 func @map_mismatch_arguments_and_dimensions(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{applied to a subset of dimensions currently not supported: operand dimensions = 2, requested map dimensions size = 3}} - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } @@ -538,48 +538,48 @@ func @map_mismatch_arguments_and_dimensions(%arg0: tensor<4x5xf32>, %arg1: tenso // CHECK-LABEL: func @map_unranked func @map_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.map"(%arg0, %arg1) ( { + %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor - "xla_hlo.return"(%1) : (tensor) -> () + %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor + "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // ----- -func @recv_invalid_number_of_results(%token: !xla_hlo.token) -> tuple, tensor, !xla_hlo.token> { +func @recv_invalid_number_of_results(%token: !mhlo.token) -> tuple, tensor, !mhlo.token> { // expected-error@+1 {{result is expected to be a tuple of size 2, but got 3}} - %0 = "xla_hlo.recv"(%token) { + %0 = "mhlo.recv"(%token) { channel_id = { handle = 5 : i64, type = 3 : i64 // Host to device channel }, is_host_transfer = true - } : (!xla_hlo.token) -> tuple, tensor, !xla_hlo.token> - return %0 : tuple, tensor, !xla_hlo.token> + } : (!mhlo.token) -> tuple, tensor, !mhlo.token> + return %0 : tuple, tensor, !mhlo.token> } // ----- -func @recv_non_token_second_result(%token: !xla_hlo.token) -> tuple, tensor> { +func @recv_non_token_second_result(%token: !mhlo.token) -> tuple, tensor> { // expected-error@+1 {{second element of result tuple is expected to be of token type, but got 'tensor'}} - %0 = "xla_hlo.recv"(%token) { + %0 = "mhlo.recv"(%token) { channel_id = { handle = 5 : i64, type = 3 : i64 // Host to device channel }, is_host_transfer = true - } : (!xla_hlo.token) -> tuple, tensor> + } : (!mhlo.token) -> tuple, tensor> return %0 : tuple, tensor> } // ----- func @rng_uniform_invalid_type(%mu: tensor>, %sigma: tensor) -> tensor<2x3x5xf32> { - %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> + %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{but got 'tensor>'}} - %0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng_uniform"(%mu, %sigma, %shape) : (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> return %0 : tensor<2x3x5xf32> } @@ -587,7 +587,7 @@ func @rng_uniform_invalid_type(%mu: tensor>, %sigma: tensor) - // CHECK-LABEL: func @select func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -595,7 +595,7 @@ func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi3 // CHECK-LABEL: func @select_scalar_pred func @select_scalar_pred(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -603,7 +603,7 @@ func @select_scalar_pred(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tenso // CHECK-LABEL: func @select_cast_compatible_types func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<*xi32>, %arg2: tensor<2x3xi32>) -> tensor<*xi32> { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<*xi32>, tensor<2x3xi32>) -> tensor<*xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<*xi32>, tensor<2x3xi32>) -> tensor<*xi32> return %0 : tensor<*xi32> } @@ -612,7 +612,7 @@ func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<*xi32>, %arg func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<2x?xi32>, %arg2: tensor) -> tensor { // TODO(lucyfox): Update once this is supported. // expected-error@+1 {{currently unsupported operand types: 'tensor<2x?xi32>' and 'tensor'}} - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x?xi32>, tensor) -> tensor + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x?xi32>, tensor) -> tensor return %0 : tensor } @@ -620,7 +620,7 @@ func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<2x?xi32>, %a // CHECK-LABEL: func @select_scalar_x_y func @select_scalar_x_y(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -628,7 +628,7 @@ func @select_scalar_x_y(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) values}} - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -636,7 +636,7 @@ func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // expected-error@+1 {{incompatible operand types: 'tensor<2x4xi32>' and 'tensor<2x3xi32>'}} - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -644,7 +644,7 @@ func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x4xi32>, %ar func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // expected-error@+1 {{incompatible operand types: 'tensor<2x3xf32>' and 'tensor<2x3xi32>'}} - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -652,7 +652,7 @@ func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf3 // CHECK-LABEL: func @slice func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { - %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -660,7 +660,7 @@ func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { // expected-error@+1 {{failed to verify that all of {start_indices, limit_indices, strides} have same type}} - %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -668,7 +668,7 @@ func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> { // expected-error@+1 {{requires the same element type for all operands and results}} - %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32> + %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -676,7 +676,7 @@ func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> { // CHECK-LABEL: func @dynamic_slice func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -684,7 +684,7 @@ func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // expected-error@+1 {{has mismatched number of slice sizes (1) and number of start indices (2)}} - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -692,7 +692,7 @@ func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor, // CHECK-LABEL: @dynamic_slice_different_indice_element_type func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor) -> tensor<1x4xi32> { - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -700,7 +700,7 @@ func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xf32> { // expected-error@+1 {{failed to verify that all of {operand, result} have same element type}} - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xf32> + %0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -708,7 +708,7 @@ func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> + %0 = "mhlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -716,7 +716,7 @@ func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) // CHECK-LABEL: @dynamic_update_slice func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start1: tensor, %start2: tensor) -> tensor<3x4xi64> { - %0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor, tensor) -> tensor<3x4xi64> + %0 = "mhlo.dynamic-update-slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor, tensor) -> tensor<3x4xi64> return %0 : tensor<3x4xi64> } @@ -724,7 +724,7 @@ func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %sta func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> { // expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} - %0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> + %0 = "mhlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> return %0 : tensor<3x4xi64> } @@ -732,21 +732,21 @@ func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tenso // CHECK-LABEL: func @transpose func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } // ----- func @transpose_ranked(%arg0: tensor) -> tensor { - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor) -> tensor + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor) -> tensor return %0: tensor } // ----- func @transpose_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<*xi32>) -> tensor<*xi32> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<*xi32>) -> tensor<*xi32> return %0: tensor<*xi32> } @@ -754,7 +754,7 @@ func @transpose_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { func @transpose_bad_permutations_rank(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // expected-error@+1 {{permutation has rank 2 instead of rank 1}} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[[1]]> : tensor<1x1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[[1]]> : tensor<1x1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } @@ -762,7 +762,7 @@ func @transpose_bad_permutations_rank(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1 func @transpose_bad_permutations_size(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // expected-error@+1 {{operand rank (4) does not match permutation size (1)}} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1]> : tensor<1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1]> : tensor<1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } @@ -770,7 +770,7 @@ func @transpose_bad_permutations_size(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1 func @transpose_operand_result_rank_mismatch(%arg0: tensor<1x2x3x4xi32>) -> tensor<2xi32> { // expected-error@+1 {{result rank (1) does not match permutation size (4)}} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2xi32> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } @@ -778,14 +778,14 @@ func @transpose_operand_result_rank_mismatch(%arg0: tensor<1x2x3x4xi32>) -> ten func @transpose_operand_result_permutation_mismatch(%arg0: tensor<1x?x3x?xi32>) -> tensor { // expected-error@+1 {{result type tensor is incompatible with the expected type tensor}} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x?x3x?xi32>) -> tensor + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x?x3x?xi32>) -> tensor return %0: tensor } // ----- func @triangular_solve_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -793,7 +793,7 @@ func @triangular_solve_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> t func @triangular_solve_rank_less_than_2(%arg0: tensor<4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { // expected-error@+1 {{operand 'a' must have rank >= 2, but got 'tensor<4xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> return %0 : tensor<4x3xf32> } @@ -801,7 +801,7 @@ func @triangular_solve_rank_less_than_2(%arg0: tensor<4xf32>, %arg1: tensor<4x3x func @triangular_solve_unequal_minor_dims_a(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { // expected-error@+1 {{two minor dimensions of operand 'a' must have equal size, but got 'tensor<4x3xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> return %0 : tensor<4x3xf32> } @@ -809,7 +809,7 @@ func @triangular_solve_unequal_minor_dims_a(%arg0: tensor<4x3xf32>, %arg1: tenso func @triangular_solve_unequal_rank(%arg0: tensor<10x4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { // expected-error@+1 {{operands must have equal rank, but got 'tensor<10x4x4xf32>' and 'tensor<4x3xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> return %0 : tensor<4x3xf32> } @@ -817,7 +817,7 @@ func @triangular_solve_unequal_rank(%arg0: tensor<10x4x4xf32>, %arg1: tensor<4x3 func @triangular_solve_mismatch_shared_dim(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf32> { // expected-error@+1 {{shared dimension of operands 'a' and 'b' does not match, but got 'tensor<4x4xf32>' and 'tensor<3x4xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> return %0 : tensor<3x4xf32> } @@ -825,7 +825,7 @@ func @triangular_solve_mismatch_shared_dim(%arg0: tensor<4x4xf32>, %arg1: tensor func @triangular_solve_mismatch_leading_dims(%arg0: tensor<10x5x4x4xf32>, %arg1: tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> { // expected-error@+1 {{leading batch dimensions of the operands must be same, but got 'tensor<10x5x4x4xf32>' and 'tensor<10x6x4x3xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x5x4x4xf32>, tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x5x4x4xf32>, tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> return %0 : tensor<10x6x4x3xf32> } @@ -833,7 +833,7 @@ func @triangular_solve_mismatch_leading_dims(%arg0: tensor<10x5x4x4xf32>, %arg1: func @triangular_solve_mismatch_result_and_b_type(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{result and operand 'b' must have same shape, but got 'tensor<4x4xf32>' and 'tensor<4x3xf32>'}} - %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x4xf32> + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x4xf32> return %0 : tensor<4x4xf32> } @@ -841,7 +841,7 @@ func @triangular_solve_mismatch_result_and_b_type(%arg0: tensor<4x4xf32>, %arg1: // CHECK-LABEL: func @tuple func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> { - %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> + %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> return %0: tuple, tensor<1x2xf32>> } @@ -849,7 +849,7 @@ func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple func @tuple_arg_size_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, tensor, tensor> { // expected-error@+1 {{has return type tuple, tensor, tensor>, but expected tuple, tensor>}} - %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor, tensor> + %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor, tensor> return %0 : tuple, tensor, tensor> } @@ -857,29 +857,29 @@ func @tuple_arg_size_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, %arg1: tensor) -> tuple, tensor> { // expected-error@+1 {{has return type tuple, tensor>, but expected tuple, tensor>}} - %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor> + %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor> return %0 : tuple, tensor> } // ----- func @get_tuple_element(%arg0: tuple, tensor>) -> tensor { - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor return %0 : tensor } // ----- -func @get_tuple_element_token(%arg0: tuple, !xla_hlo.token>) -> !xla_hlo.token { - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, !xla_hlo.token>) -> !xla_hlo.token - return %0 : !xla_hlo.token +func @get_tuple_element_token(%arg0: tuple, !mhlo.token>) -> !mhlo.token { + %0 = "mhlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, !mhlo.token>) -> !mhlo.token + return %0 : !mhlo.token } // ----- func @get_tuple_element_bad_type(%arg0: tuple, tensor>) -> tensor { // expected-error@+1 {{has return type tensor, but expected tensor}} - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor return %0 : tensor } @@ -887,7 +887,7 @@ func @get_tuple_element_bad_type(%arg0: tuple, tensor>) -> tens func @get_tuple_element_index_out_of_bounds(%arg0: tuple, tensor>) -> tensor { // expected-error@+1 {{index 2 is out of bounds of operand with size 2}} - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, tensor>) -> tensor + %0 = "mhlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, tensor>) -> tensor return %0 : tensor } @@ -895,14 +895,14 @@ func @get_tuple_element_index_out_of_bounds(%arg0: tuple, tensor, %arg1: tensor<4xi32>) -> tensor<4xi32> { - %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %0 = "mhlo.and"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } // ----- // CHECK-LABEL: func @or_i1_type func @or_i1_type(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + %0 = "mhlo.or"(%arg0, %arg1) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } @@ -910,7 +910,7 @@ func @or_i1_type(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // expected-error@+1 {{but got 'tensor<4xf32>'}} - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "mhlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -918,7 +918,7 @@ func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor< func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> { // expected-error@+1 {{must be tensor of floating-point values, but got 'tensor<4xi32>'}} - %0 = "xla_hlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %0 = "mhlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -927,11 +927,11 @@ func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> { // Verifiers HLO constant op custom printing and parsing. // CHECK-LABEL: func @constants func @constants() -> () { - // CHECK: xla_hlo.constant dense<0> : tensor - %0 = "xla_hlo.constant"() {value = dense<0> : tensor} : () -> (tensor) + // CHECK: mhlo.constant dense<0> : tensor + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> (tensor) - // CHECK: xla_hlo.constant {extra_attr = 3 : i32} dense<0> : tensor - %1 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor} : () -> (tensor) + // CHECK: mhlo.constant {extra_attr = 3 : i32} dense<0> : tensor + %1 = "mhlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor} : () -> (tensor) return } @@ -939,18 +939,18 @@ func @constants() -> () { func @constant_invalid() -> () { // expected-error@+1 {{op failed to verify that all of {value, output} have same type}} - %0 = "xla_hlo.constant"() {value = dense<0> : tensor} : () -> (tensor<*xi32>) + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> (tensor<*xi32>) return } // ----- func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { - // CHECK: xla_hlo.sort - %0 = "xla_hlo.sort"(%input0, %input1) ( { + // CHECK: mhlo.sort + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -959,10 +959,10 @@ func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { func @sort_no_operands() { // expected-error @+1 {{op requires at least one input}} - %0 = "xla_hlo.sort"() ( { + %0 = "mhlo.sort"() ( { ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): - %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : () -> tuple<> return } @@ -970,10 +970,10 @@ func @sort_no_operands() { // ----- func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -982,10 +982,10 @@ func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{comparator block argument #0 should be of type 'tensor' but got 'tensor'}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -994,10 +994,10 @@ func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op requires all inputs to have the same dimensions}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -1006,10 +1006,10 @@ func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found 10}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -1018,10 +1018,10 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3 func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found -3}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -1030,10 +1030,10 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3 func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op comparator block should have 4 arguments}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -1042,10 +1042,10 @@ func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op comparator block argument #3 should be of type 'tensor' but got 'tensor'}} - %0 = "xla_hlo.sort"(%input0, %input1) ( { + %0 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%7) : (tensor) -> () + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> return } @@ -1054,7 +1054,7 @@ func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x1 // CHECK: func @dequantize func @dequantize(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { - %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> + %0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> return %0 : tensor<16x64xbf16> } @@ -1062,7 +1062,7 @@ func @dequantize(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { func @dequantize_wrong_shape(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { // expected-error @+1 {{mismatched dimensions.}} - %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = true} : (tensor<16x16xi32>) -> tensor<16x64xbf16> + %0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = true} : (tensor<16x16xi32>) -> tensor<16x64xbf16> return %0 : tensor<16x64xbf16> } @@ -1070,7 +1070,7 @@ func @dequantize_wrong_shape(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { func @dequantize_wrong_size(%arg: tensor<16x16xi32>) -> tensor<16x16xbf16> { // expected-error @+1 {{last dimension of output should be 4x of the input.}} - %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x16xbf16> + %0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x16xbf16> return %0 : tensor<16x16xbf16> } @@ -1078,7 +1078,7 @@ func @dequantize_wrong_size(%arg: tensor<16x16xi32>) -> tensor<16x16xbf16> { func @dequantize_wrong_mode(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { // expected-error @+1 {{Dequantization mode. Only MIN_COMBINED is supported.}} - %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "hello", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> + %0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "hello", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> return %0 : tensor<16x64xbf16> } @@ -1086,7 +1086,7 @@ func @dequantize_wrong_mode(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> { // expected-error @+1 {{number of output elements (9) doesn't match expected number of elements (8)}} - %0 = "xla_hlo.reshape"(%operand) : (tensor<2x4xf32>) -> tensor<3x3xf32> + %0 = "mhlo.reshape"(%operand) : (tensor<2x4xf32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> } @@ -1094,7 +1094,7 @@ func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> { func @dot_general(%arg0: tensor, %arg1: tensor) { // expected-error @+1 {{lhs and rhs should have the same number of batching dimensions}} - %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { + %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, @@ -1107,7 +1107,7 @@ func @dot_general(%arg0: tensor, %arg1: tensor) { func @dot_general(%arg0: tensor, %arg1: tensor) { // expected-error @+1 {{lhs and rhs should have the same number of contracting dimensions}} - %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { + %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[]> : tensor<0xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, @@ -1119,7 +1119,7 @@ func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func @compatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor { - %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor + %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor return %0 : tensor } @@ -1127,6 +1127,6 @@ func @compatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor func @incompatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor { // expected-error @+1 {{output should have a rank equal to the number of elements in output_shape}} - %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor + %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor return %0 : tensor } diff --git a/tests/reduce.mlir b/tests/reduce.mlir index 4566b63..586a199 100644 --- a/tests/reduce.mlir +++ b/tests/reduce.mlir @@ -4,11 +4,11 @@ // CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>) // CHECK: return %[[ARG0]] func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %2 = "xla_hlo.reduce"(%arg0, %0) ( { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %2 = "mhlo.reduce"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): - %4 = xla_hlo.add %arg1, %arg2 : tensor - "xla_hlo.return"(%4) : (tensor) -> () + %4 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%4) : (tensor) -> () }) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> } diff --git a/tests/reshape.mlir b/tests/reshape.mlir index c9e6c5a..9aa28a4 100644 --- a/tests/reshape.mlir +++ b/tests/reshape.mlir @@ -2,9 +2,9 @@ // CHECK-LABEL: func @const_fold_collapse_to_scalar func @const_fold_collapse_to_scalar() -> tensor { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor - %cst = xla_hlo.constant dense<42> : tensor<1x1xi32> - %0 = "xla_hlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor + %cst = mhlo.constant dense<42> : tensor<1x1xi32> + %0 = "mhlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor // CHECK-NEXT: return [[CST]] return %0 : tensor } @@ -13,9 +13,9 @@ func @const_fold_collapse_to_scalar() -> tensor { // CHECK-LABEL: func @const_fold_collapse_to_tensor func @const_fold_collapse_to_tensor() -> tensor<2xi32> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<2xi32> - %cst = xla_hlo.constant dense<42> : tensor<1x2xi32> - %0 = "xla_hlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32> + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<2xi32> + %cst = mhlo.constant dense<42> : tensor<1x2xi32> + %0 = "mhlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32> // CHECK-NEXT: return [[CST]] return %0 : tensor<2xi32> } @@ -24,9 +24,9 @@ func @const_fold_collapse_to_tensor() -> tensor<2xi32> { // CHECK-LABEL: func @const_fold_expand func @const_fold_expand() -> tensor<1xi32> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<1xi32> - %cst = xla_hlo.constant dense<42> : tensor - %0 = "xla_hlo.reshape"(%cst) : (tensor) -> tensor<1xi32> + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<1xi32> + %cst = mhlo.constant dense<42> : tensor + %0 = "mhlo.reshape"(%cst) : (tensor) -> tensor<1xi32> // CHECK-NEXT: return [[CST]] return %0 : tensor<1xi32> } @@ -35,9 +35,9 @@ func @const_fold_expand() -> tensor<1xi32> { // CHECK-LABEL: func @const_fold_nontrivial func @const_fold_nontrivial() -> tensor<16xi64> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64> - %cst = xla_hlo.constant dense<42> : tensor<4x4xi64> - %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64> + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64> + %cst = mhlo.constant dense<42> : tensor<4x4xi64> + %0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64> // CHECK-NEXT: return [[CST]] return %0 : tensor<16xi64> } @@ -46,9 +46,9 @@ func @const_fold_nontrivial() -> tensor<16xi64> { // CHECK-LABEL: func @const_fold_flatten func @const_fold_flatten() -> tensor<16xi64> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64> - %cst = xla_hlo.constant dense<42> : tensor<4x4xi64> - %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64> + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64> + %cst = mhlo.constant dense<42> : tensor<4x4xi64> + %0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64> // CHECK-NEXT: return [[CST]] return %0 : tensor<16xi64> } @@ -57,9 +57,9 @@ func @const_fold_flatten() -> tensor<16xi64> { // CHECK-LABEL: func @const_fold_6 func @const_fold_6() -> tensor<6xi32> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> - %cst = xla_hlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> - %0 = "xla_hlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32> + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> + %cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> + %0 = "mhlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32> // CHECK-NEXT: return [[CST]] return %0 : tensor<6xi32> } @@ -68,11 +68,11 @@ func @const_fold_6() -> tensor<6xi32> { // CHECK-LABEL: func @const_fold_same_shape func @const_fold_same_shape() -> tensor<2x3xi32> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[ + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[ // CHECK-SAME: [1, 2, 3], [4, 5, 6] // CHECK-SAME: ]> : tensor<2x3xi32> - %cst = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> - %0 = "xla_hlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32> + %cst = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> + %0 = "mhlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32> // CHECK-NEXT: return [[CST]] return %0 : tensor<2x3xi32> } @@ -81,9 +81,9 @@ func @const_fold_same_shape() -> tensor<2x3xi32> { // CHECK-LABEL: func @const_fold_float func @const_fold_float() -> tensor<16xf64> { - // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64> - %cst = xla_hlo.constant dense<4.2> : tensor<4x4xf64> - %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64> + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64> + %cst = mhlo.constant dense<4.2> : tensor<4x4xf64> + %0 = "mhlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64> // CHECK-NEXT: return [[CST]] return %0 : tensor<16xf64> } @@ -94,7 +94,7 @@ func @const_fold_float() -> tensor<16xf64> { // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK-NEXT: return [[ARG]] - %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -103,10 +103,10 @@ func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK-LABEL: func @non_const_chained_reshape // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) { - // CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32> - // CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> - %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> - %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> + // CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32> + // CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> + %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed } @@ -115,9 +115,9 @@ func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, ten // CHECK-LABEL: func @non_const_chained_reshape_unused_parent // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> - %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> - %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> + // CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> + %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> // CHECK-NEXT: return [[RES]] return %1 : tensor<6xi32> } @@ -127,8 +127,8 @@ func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor< // CHECK-LABEL: func @non_const_chained_reshape_becomes_noop // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> { - %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> - %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32> + %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32> // CHECK-NEXT: return [[ARG]] return %1 : tensor<2x3xi32> } @@ -138,12 +138,12 @@ func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2 // CHECK-LABEL: func @non_const_many_chained_reshapes // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> { - // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> - %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32> - %1 = "xla_hlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32> - %2 = "xla_hlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32> - %3 = "xla_hlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32> - %4 = "xla_hlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32> + // CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> + %0 = "mhlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32> + %1 = "mhlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32> + %2 = "mhlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32> + %3 = "mhlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32> + %4 = "mhlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32> // CHECK-NEXT: return [[RES]] return %4 : tensor<1x2x4x3xi32> } diff --git a/tests/reverse.mlir b/tests/reverse.mlir index 9a1c113..6e291af 100644 --- a/tests/reverse.mlir +++ b/tests/reverse.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func @noop // CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>) func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { - %0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %0 = "mhlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32> // CHECK: return %[[ARG0]] return %0 : tensor<1x2xf32> } diff --git a/tests/sink-constants-to-control-flow.mlir b/tests/sink-constants-to-control-flow.mlir index 35682a5..6a35239 100644 --- a/tests/sink-constants-to-control-flow.mlir +++ b/tests/sink-constants-to-control-flow.mlir @@ -4,27 +4,27 @@ // CHECK-LABEL: func @sink_const_to_while func @sink_const_to_while(%arg0: tensor) -> tensor { - // CHECK-NEXT: xla_hlo.while - %c0 = xla_hlo.constant dense<1> : tensor - %c1 = xla_hlo.constant dense<2> : tensor - %0 = "xla_hlo.while"(%arg0) ( { + // CHECK-NEXT: mhlo.while + %c0 = mhlo.constant dense<1> : tensor + %c1 = mhlo.constant dense<2> : tensor + %0 = "mhlo.while"(%arg0) ( { ^bb0(%arg1: tensor): // CHECK: %[[ARG1A:.+]]: tensor - // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor - // CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]]) - %1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%1) : (tensor) -> () + // CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor + // CHECK: "mhlo.compare"(%[[C0]], %[[ARG1A]]) + %1 = "mhlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg1: tensor): // CHECK: %[[ARG1B:.+]]: tensor - // CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor - // CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]] - %2 = xla_hlo.add %arg1, %arg1 : tensor - // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]] - %3 = xla_hlo.add %c1, %2 : tensor - // CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]] - %4 = xla_hlo.add %c1, %3 : tensor - "xla_hlo.return"(%4) : (tensor) -> () + // CHECK-DAG: %[[C1:.+]] = mhlo.constant dense<2> : tensor + // CHECK-DAG: %[[ADD0:.+]] = mhlo.add %[[ARG1B]], %[[ARG1B]] + %2 = mhlo.add %arg1, %arg1 : tensor + // CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]], %[[ADD0]] + %3 = mhlo.add %c1, %2 : tensor + // CHECK: %[[ADD2:.+]] = mhlo.add %[[C1]], %[[ADD1]] + %4 = mhlo.add %c1, %3 : tensor + "mhlo.return"(%4) : (tensor) -> () }) : (tensor) -> tensor return %0 : tensor } @@ -33,28 +33,28 @@ func @sink_const_to_while(%arg0: tensor) -> tensor { // CHECK-LABEL: func @sink_const_to_conditional func @sink_const_to_conditional(%arg0: tensor) -> tensor { - %c0 = xla_hlo.constant dense<1> : tensor - %c1 = xla_hlo.constant dense<2> : tensor - %0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - %1 = "xla_hlo.tuple"(%arg0) : (tensor) -> tuple> - // CHECK: xla_hlo.if - %2 = "xla_hlo.if"(%0, %1, %1) ( { + %c0 = mhlo.constant dense<1> : tensor + %c1 = mhlo.constant dense<2> : tensor + %0 = "mhlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + %1 = "mhlo.tuple"(%arg0) : (tensor) -> tuple> + // CHECK: mhlo.if + %2 = "mhlo.if"(%0, %1, %1) ( { ^bb0(%arg1: tuple>): - // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor - %3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor - // CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]], - %4 = xla_hlo.add %c0, %3 : tensor - %5 = "xla_hlo.tuple"(%4) : (tensor) -> tuple> - "xla_hlo.return"(%5) : (tuple>) -> () + // CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor + %3 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD0:.+]] = mhlo.add %[[C0]], + %4 = mhlo.add %c0, %3 : tensor + %5 = "mhlo.tuple"(%4) : (tensor) -> tuple> + "mhlo.return"(%5) : (tuple>) -> () }, { ^bb0(%arg1: tuple>): - // CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor - %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor - // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], - %7 = xla_hlo.add %c1, %6 : tensor - %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple> - "xla_hlo.return"(%8) : (tuple>) -> () + // CHECK: %[[C1:.+]] = mhlo.constant dense<2> : tensor + %6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]], + %7 = mhlo.add %c1, %6 : tensor + %8 = "mhlo.tuple"(%7) : (tensor) -> tuple> + "mhlo.return"(%8) : (tuple>) -> () }) : (tensor, tuple>, tuple>) -> tuple> - %9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor + %9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor return %9 : tensor } diff --git a/tests/transpose.mlir b/tests/transpose.mlir index ce11a2a..bbfedc5 100644 --- a/tests/transpose.mlir +++ b/tests/transpose.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func @remove_noop // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> { - %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> + %0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> // CHECK-NEXT: return [[ARG]] return %0 : tensor<2x3x9x5xi32> } @@ -13,8 +13,8 @@ func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> { // CHECK-LABEL: func @keep_real_transpose // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { - // CHECK-NEXT: "xla_hlo.transpose"([[ARG]]) - %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> + // CHECK-NEXT: "mhlo.transpose"([[ARG]]) + %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> return %0 : tensor<3x2x5x9xi32> } @@ -23,7 +23,7 @@ func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { // CHECK-LABEL: func @keep_same_shape_real_transpose // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> { - // CHECK-NEXT: "xla_hlo.transpose"([[ARG]]) - %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32> + // CHECK-NEXT: "mhlo.transpose"([[ARG]]) + %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32> return %0 : tensor<4x4xi32> } diff --git a/tests/tuple.mlir b/tests/tuple.mlir index bf68009..4ecc1e3 100644 --- a/tests/tuple.mlir +++ b/tests/tuple.mlir @@ -4,7 +4,7 @@ // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @fold_access(%arg : tensor) -> tensor { // CHECK-NEXT: return [[ARG]] - %tuple = "xla_hlo.tuple"(%arg) : (tensor) -> tuple> - %element = "xla_hlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple>) -> tensor + %tuple = "mhlo.tuple"(%arg) : (tensor) -> tuple> + %element = "mhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple>) -> tensor return %element : tensor } diff --git a/tests/unfuse_batch_norm.mlir b/tests/unfuse_batch_norm.mlir index cefceeb..296bba7 100644 --- a/tests/unfuse_batch_norm.mlir +++ b/tests/unfuse_batch_norm.mlir @@ -10,19 +10,19 @@ func @batchNormInference_2D_inner_features( %x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor<4x256xf32>) { - // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor - // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<256xf32> - // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> - // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32> - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor + // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<256xf32> + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> + // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32> + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> tensor<4x256xf32> @@ -36,12 +36,12 @@ func @batchNormInference_2D_inner_features( // the verifier to enforce the rest. // CHECK-SAME: %[[X:[^:]+]] // CHECK-SAME: %[[SCALE:[^:]+]] -// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> +// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> func @batchNormInference_4D_middle_features( %x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor<3x4x256x6xf32>) { - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} : (tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> tensor<3x4x256x6xf32> @@ -51,12 +51,12 @@ func @batchNormInference_4D_middle_features( // ----- // CHECK-LABEL: @batchNormInference_f64 // Validate that epsilon is properly promoted to f64 -// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor +// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor func @batchNormInference_f64( %x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>, %mean: tensor<256xf64>, %variance: tensor<256xf64>) -> (tensor<4x256xf64>) { - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 1.0 : f32, feature_index = 1 : i64} : (tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>) -> tensor<4x256xf64> @@ -66,12 +66,12 @@ func @batchNormInference_f64( // ----- // CHECK-LABEL: @batchNormInference_f16 // Validate that epsilon is properly promoted to f64 -// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor +// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor func @batchNormInference_f16( %x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>, %mean: tensor<256xf16>, %variance: tensor<256xf16>) -> (tensor<4x256xf16>) { - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 1.0 : f32, feature_index = 1 : i64} : (tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>) -> tensor<4x256xf16> @@ -85,7 +85,7 @@ func @batchNormInference_f16_overflow( %mean: tensor<256xf16>, %variance: tensor<256xf16>) -> (tensor<4x256xf16>) { // expected-warning @+1 {{Could not convert batch_norm epsilon to target fp type: opStatus = 24}} - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 0.00000001 : f32, feature_index = 1 : i64} : (tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>) -> tensor<4x256xf16> @@ -108,26 +108,26 @@ func @batchNormInference_dynamic_shape( // CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C3:.*]] = constant 3 : index - // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor + // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex> - // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor - // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor - // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor + // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor + // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor - // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor - // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor - // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor - %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor + // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor + // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor + // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 0.001 : f32, feature_index = 1 : i64} : (tensor, tensor, tensor, tensor, tensor) -> tensor diff --git a/tests/xla-hlo-fusion.mlir b/tests/xla-hlo-fusion.mlir index 7061bc2..6dc079a 100644 --- a/tests/xla-hlo-fusion.mlir +++ b/tests/xla-hlo-fusion.mlir @@ -2,14 +2,14 @@ // CHECK-LABEL: func @multi_outputs_same func @multi_outputs_same(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "xla_hlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor - %2 = "xla_hlo.add"(%1, %1) : (tensor, tensor) -> tensor - // CHECK: %[[RET:.*]]:2 = "xla_hlo.fusion" - // CHECK-NEXT: xla_hlo.add - // CHECK-NEXT: xla_hlo.subtract - // CHECK-NEXT: xla_hlo.add - // CHECK-NEXT: xla_hlo.return + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %1 = "mhlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor + %2 = "mhlo.add"(%1, %1) : (tensor, tensor) -> tensor + // CHECK: %[[RET:.*]]:2 = "mhlo.fusion" + // CHECK-NEXT: mhlo.add + // CHECK-NEXT: mhlo.subtract + // CHECK-NEXT: mhlo.add + // CHECK-NEXT: mhlo.return return %1, %2 : tensor, tensor } @@ -17,18 +17,18 @@ func @multi_outputs_same(%arg0: tensor, %arg1: tensor) -> (ten // CHECK-LABEL: func @multi_outputs_same_2 func @multi_outputs_same_2(%arg0: tensor, %arg1: tensor) -> (tensor, tensor, tensor) { - %0 = "xla_hlo.abs"(%arg0) : (tensor) -> tensor - %1 = "xla_hlo.abs"(%arg1) : (tensor) -> tensor - %2 = "xla_hlo.add"(%0, %1) : (tensor, tensor) -> tensor - %3 = "xla_hlo.abs"(%0) : (tensor) -> tensor - %4 = "xla_hlo.abs"(%1) : (tensor) -> tensor - // CHECK: %[[RET:.*]]:3 = "xla_hlo.fusion" - // CHECK-NEXT: xla_hlo.abs - // CHECK-NEXT: xla_hlo.abs - // CHECK-NEXT: xla_hlo.add - // CHECK-NEXT: xla_hlo.abs - // CHECK-NEXT: xla_hlo.abs - // CHECK-NEXT: xla_hlo.return + %0 = "mhlo.abs"(%arg0) : (tensor) -> tensor + %1 = "mhlo.abs"(%arg1) : (tensor) -> tensor + %2 = "mhlo.add"(%0, %1) : (tensor, tensor) -> tensor + %3 = "mhlo.abs"(%0) : (tensor) -> tensor + %4 = "mhlo.abs"(%1) : (tensor) -> tensor + // CHECK: %[[RET:.*]]:3 = "mhlo.fusion" + // CHECK-NEXT: mhlo.abs + // CHECK-NEXT: mhlo.abs + // CHECK-NEXT: mhlo.add + // CHECK-NEXT: mhlo.abs + // CHECK-NEXT: mhlo.abs + // CHECK-NEXT: mhlo.return return %2, %3, %4 : tensor, tensor, tensor } @@ -36,9 +36,9 @@ func @multi_outputs_same_2(%arg0: tensor, %arg1: tensor) -> (t // CHECK-LABEL: func @multi_outputs_not_sure_same func @multi_outputs_not_sure_same(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "xla_hlo.add"(%arg0, %arg0) : (tensor, tensor) -> tensor - // CHECK-NOT: xla_hlo.fusion - %1 = "xla_hlo.subtract"(%arg1, %arg1) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg0) : (tensor, tensor) -> tensor + // CHECK-NOT: mhlo.fusion + %1 = "mhlo.subtract"(%arg1, %arg1) : (tensor, tensor) -> tensor return %0, %1 : tensor, tensor } @@ -46,25 +46,25 @@ func @multi_outputs_not_sure_same(%arg0: tensor, %arg1: tensor // CHECK-LABEL: func @reduce func @reduce(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "xla_hlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor - // CHECK: %[[RET0:.*]] = "xla_hlo.fusion" - // CHECK-NEXT: xla_hlo.add - // CHECK-NEXT: xla_hlo.subtract - // CHECK-NEXT: xla_hlo.return + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %1 = "mhlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor + // CHECK: %[[RET0:.*]] = "mhlo.fusion" + // CHECK-NEXT: mhlo.add + // CHECK-NEXT: mhlo.subtract + // CHECK-NEXT: mhlo.return // Currently we do not support fuse arguments and ops without direct producer-consumer // relationship. Thus Reduce Op should not be fused with above two ops. - %2 = xla_hlo.constant dense<0.000000e+00> : tensor - %3 = "xla_hlo.reduce"(%arg0, %2) ( { + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.reduce"(%arg0, %2) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %4 = "xla_hlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor - "xla_hlo.return"(%4) : (tensor) -> () + %4 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "mhlo.return"(%4) : (tensor) -> () }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor - %4 = "xla_hlo.add"(%3, %3) : (tensor, tensor) -> tensor + %4 = "mhlo.add"(%3, %3) : (tensor, tensor) -> tensor // Above two ops should not be fused since reduce op can not be // fused with its consumer. - // CHECK-NOT: xla_hlo.fusion + // CHECK-NOT: mhlo.fusion return %1, %4 : tensor, tensor } @@ -73,25 +73,25 @@ func @reduce(%arg0: tensor, %arg1: tensor) -> (tensor // CHECK-LABEL: func @reduce_2 func @reduce_2(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "xla_hlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %1 = "mhlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor - %2 = xla_hlo.constant dense<0.000000e+00> : tensor - %3 = "xla_hlo.reduce"(%1, %2) ( { + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.reduce"(%1, %2) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %4 = "xla_hlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor - "xla_hlo.return"(%4) : (tensor) -> () + %4 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "mhlo.return"(%4) : (tensor) -> () }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor - // CHECK: %[[RET0:.*]]:2 = "xla_hlo.fusion" - // CHECK-NEXT: xla_hlo.add - // CHECK-NEXT: xla_hlo.subtract - // CHECK-NEXT: xla_hlo.constant - // CHECK-NEXT: xla_hlo.reduce - // CHECK: xla_hlo.return + // CHECK: %[[RET0:.*]]:2 = "mhlo.fusion" + // CHECK-NEXT: mhlo.add + // CHECK-NEXT: mhlo.subtract + // CHECK-NEXT: mhlo.constant + // CHECK-NEXT: mhlo.reduce + // CHECK: mhlo.return // Following op should not be fused with the above ops since reduce op can not be // fused with its consumer. - // CHECK-NOT: xla_hlo.fusion - %4 = "xla_hlo.add"(%3, %3) : (tensor, tensor) -> tensor + // CHECK-NOT: mhlo.fusion + %4 = "mhlo.add"(%3, %3) : (tensor, tensor) -> tensor return %1, %4 : tensor, tensor } diff --git a/tests/xla-transform-unranked-hlo.mlir b/tests/xla-transform-unranked-hlo.mlir index eb98789..8047415 100644 --- a/tests/xla-transform-unranked-hlo.mlir +++ b/tests/xla-transform-unranked-hlo.mlir @@ -9,15 +9,15 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { %num_elements = shape.num_elements %shape %num_elements_as_index = shape.size_to_index %num_elements %flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex> - %flat_a = "xla_hlo.dynamic_reshape"(%a, %flat_shape) + %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor // Apply operation. - %flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor) -> tensor + %flat_b = "mhlo.sqrt"(%flat_a) : (tensor) -> tensor // Restore original shape. %shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor - %b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor) + %b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor) : (tensor, tensor) -> tensor<*xf32> return %b : tensor<*xf32> @@ -33,12 +33,12 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> { // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] // CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> - // CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor - // CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor) -> tensor + // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor) -> tensor // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor - // CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> // CHECK-NEXT: return %[[B]] : tensor<*xf32> - %b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32> + %b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32> return %b : tensor<*xf32> } @@ -48,9 +48,9 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @sqrt_ranked // CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>) func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> { - // CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32> + // CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32> // CHECK-NEXT: return %[[B]] : tensor<3x?xf32> - %b = "xla_hlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32> + %b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32> return %b : tensor<3x?xf32> } @@ -60,9 +60,9 @@ func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> { // CHECK-LABEL: @sqrt_static // CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>) func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> + // CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: return %[[B]] : tensor<2x3xf32> - %b = "xla_hlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32> + %b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32> return %b : tensor<2x3xf32> } @@ -77,12 +77,12 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] // CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> - // CHECK: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor - // CHECK: %[[FLAT_B:.*]] = "xla_hlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor - // CHECK: %[[FLAT_RESULT:.*]] = xla_hlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor + // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor // CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> // CHECK: return %[[RESULT]] : tensor<*xf32> - %result = xla_hlo.add %a, %b : tensor<*xf32> + %result = mhlo.add %a, %b : tensor<*xf32> return %result : tensor<*xf32> }