From 7c4a5d62b53d95f1c2e860ade39d3aa473deac19 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Wed, 8 Jul 2020 17:05:32 +0000 Subject: [PATCH] Rename xla_lhlo dialect into lmhlo Following on the plan of isolating the compiler/mlir/hlo directory. Another xla_lhlo dialect will be created under compiler/mlir/xla/ later. PiperOrigin-RevId: 320210326 --- include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h | 10 +- include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 10 +- .../mhlo/transforms/map_hlo_to_lhlo_op.h | 2 +- .../mhlo/transforms/map_xla_to_scalar_op.h | 186 ++++++------ .../mlir-hlo/Dialect/mhlo/transforms/passes.h | 4 +- .../Dialect/mhlo/transforms/rewriters.h | 4 +- lib/Dialect/mhlo/IR/dialect_registration.cc | 2 +- lib/Dialect/mhlo/IR/lhlo_ops.cc | 6 +- .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 46 +-- .../mhlo/transforms/lhlo_copy_removal.cc | 6 +- .../mhlo/transforms/lhlo_fuse_linalg.cc | 4 +- .../transforms/lhlo_legalize_to_affine.cc | 22 +- .../mhlo/transforms/lhlo_legalize_to_gpu.cc | 6 +- .../mhlo/transforms/lhlo_legalize_to_llvm.cc | 4 +- .../transforms/lhlo_legalize_to_llvm_pass.cc | 6 +- .../lhlo_legalize_to_parallel_loops.cc | 62 ++-- .../mhlo/transforms/xla_legalize_to_linalg.cc | 120 ++++---- tests/hlo-legalize-to-lhlo.mlir | 78 ++--- tests/lhlo-copy-removal.mlir | 68 ++--- tests/lhlo-legalize-select-and-scatter.mlir | 16 +- tests/lhlo-legalize-to-affine.mlir | 32 +-- tests/lhlo-legalize-to-gpu.mlir | 8 +- tests/lhlo-legalize-to-linalg.mlir | 108 +++---- tests/lhlo-legalize-to-llvm.mlir | 4 +- tests/lhlo-legalize-to-parallel-loops.mlir | 32 +-- tests/lhlo_ops.mlir | 270 +++++++++--------- 26 files changed, 566 insertions(+), 550 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index 0ea62ba..2aff525 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -35,18 +35,18 @@ class OpBuilder; #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc" -namespace xla_lhlo { +namespace lmhlo { -class XlaLhloDialect : public Dialect { +class LmhloDialect : public Dialect { public: - explicit XlaLhloDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "xla_lhlo"; } + explicit LmhloDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "lmhlo"; } }; #define GET_OP_CLASSES #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" -} // namespace xla_lhlo +} // namespace lmhlo } // end namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 2af7c44..167df89 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -38,8 +38,8 @@ include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInte include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" def LHLO_Dialect : Dialect { - let name = "xla_lhlo"; - let cppNamespace = "xla_lhlo"; + let name = "lmhlo"; + let cppNamespace = "lmhlo"; } //===----------------------------------------------------------------------===// @@ -253,7 +253,7 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ // TODO(timshen): Add a custom parser to hide operand_segment_sizes. 
For example, // A tuple-like pattern match syntax could work: -// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) { +// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) { // ... // }, { // ... @@ -337,7 +337,7 @@ def HLO_StaticMemRefCastOp: Op -> memref<5xf32, offset: 2, strides: [1]> // The result of the op is a rank-1 memref with `[5]` shape, stride 1 and @@ -379,7 +379,7 @@ def HLO_DynamicMemRefCastOp: Op -> memref // The result of the op is a type-erased memref with `[%size_X, %size_Y]` // shape and `[%step_X, %step_Y]` strides. The offset will be inherited diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index a05d1d3..9e7126e 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -34,7 +34,7 @@ using HloToLhloOp = typename HloToLhloOpImpl::Type; #define MAP_HLO_TO_LHLO(OpName) \ template <> \ struct HloToLhloOpImpl { \ - using Type = xla_lhlo::OpName; \ + using Type = lmhlo::OpName; \ } MAP_HLO_TO_LHLO(AbsOp); diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h index 16a31f9..be06237 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h @@ -24,7 +24,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace impl { // A struct to map LhloBinaryOpTy type to the corresponding floating-point and @@ -33,32 +33,32 @@ template struct LhloToScalarOp; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::AddFOp; using IOp = ::mlir::AddIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::CmpFOp; using IOp = ::mlir::CmpIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::DivFOp; using IOp = ::mlir::SignedDivIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::MulFOp; using IOp = ::mlir::MulIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::RemFOp; using IOp = ::mlir::SignedRemIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::SubFOp; using IOp = ::mlir::SubIOp; }; @@ -116,16 +116,17 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { Type element_type = args.front().getType(); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } if (element_type.isa()) { - // xla_lhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) + // lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); @@ -133,16 +134,17 @@ inline Value MapLhloOpToStdScalarOp( b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); auto lhs_gt_zero = b->create>(loc, CmpIPredicate::sge, lhs, zero_intval); - auto neg_val = b->create>(loc, zero_intval, lhs); + auto neg_val = 
b->create>(loc, zero_intval, lhs); return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val); } return nullptr; } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -205,30 +207,33 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc, } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return args.front(); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( +inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, @@ -236,21 +241,23 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( +inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { Type sourceType = args.front().getType(); @@ -288,9 +295,10 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { // Dot Op converter from lhlo to affine only accepts float and integer types. 
const auto& lhs = args[0]; const auto& rhs = args[1]; @@ -312,17 +320,19 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -361,66 +371,69 @@ struct XlaCompareSelectOpToStdScalarOp -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return XlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>::map(loc, "GT", - result_types, args, - b); + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "GT", result_types, + args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return XlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>::map(loc, "LT", - result_types, args, - b); + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "LT", result_types, + args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { Type element_type = args.front().getType(); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } if (element_type.isa()) { - // xla_lhlo.neg(x, result) -> result = sub(0, x) + // lmhlo.neg(x, result) -> result = sub(0, x) Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); auto zero_intval = b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); - return b->create>(loc, zero_intval, lhs); + return b->create>(loc, zero_intval, lhs); } return nullptr; } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( +inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, @@ -428,9 +441,10 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, 
ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { Type element_type = args.front().getType(); if (element_type.isa()) { FloatType float_type = element_type.cast(); @@ -442,17 +456,19 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -460,10 +476,10 @@ inline Value MapLhloOpToStdScalarOp( } // namespace impl struct XlaOpToStdScalarOp { - // Implementation for LHLO ops except xla_lhlo::CompareOp. + // Implementation for LHLO ops except lmhlo::CompareOp. template ::value && + !std::is_same::value && std::is_same, std::false_type>::value>> static Value map(XlaOpTy op, ArrayRef result_types, @@ -475,7 +491,7 @@ struct XlaOpToStdScalarOp { // Implementation for HLO ops except mhlo::CompareOp. template , typename = std::enable_if_t< - !std::is_same::value && + !std::is_same::value && !std::is_same::value>> static Value map(XlaOpTy op, ArrayRef result_types, ArrayRef args, OpBuilder* b, int i = 0) { @@ -483,13 +499,13 @@ struct XlaOpToStdScalarOp { args, b); } - // Implementation for xla_lhlo::CompareOp. + // Implementation for lmhlo::CompareOp. template ::value>> - static Value map(xla_lhlo::CompareOp op, ArrayRef result_types, + LhloOpTy, lmhlo::CompareOp>::value>> + static Value map(lmhlo::CompareOp op, ArrayRef result_types, ArrayRef args, OpBuilder* b) { auto comparison_direction = op.comparison_direction(); - return impl::MapXlaCompareOpToStdScalarOp( + return impl::MapXlaCompareOpToStdScalarOp( op.getLoc(), comparison_direction, result_types, args, b); } @@ -500,12 +516,12 @@ struct XlaOpToStdScalarOp { static Value map(mhlo::CompareOp op, ArrayRef result_types, ArrayRef args, OpBuilder* b) { auto comparison_direction = op.comparison_direction(); - return impl::MapXlaCompareOpToStdScalarOp( + return impl::MapXlaCompareOpToStdScalarOp( op.getLoc(), comparison_direction, result_types, args, b); } }; -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index b279e15..aa06493 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -60,7 +60,7 @@ std::unique_ptr> createXlaHloFusionPass(); } // namespace mhlo -namespace xla_lhlo { +namespace lmhlo { // Lowers from LHLO dialect to Affine dialect. std::unique_ptr> createLegalizeToAffinePass(); @@ -92,7 +92,7 @@ std::unique_ptr createLhloCopyRemovalPass(); // Lowers from LHLO dialect to parallel loops. 
std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); -} // namespace xla_lhlo +} // namespace lmhlo namespace xla { diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index fd0cc89..a3e36b6 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -75,14 +75,14 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context, } // namespace mhlo -namespace xla_lhlo { +namespace lmhlo { /// Collect a set of patterns to convert from the LHLO dialect to LLVM. void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, LLVMTypeConverter *converter, OwningRewritePatternList *patterns); -} // namespace xla_lhlo +} // namespace lmhlo namespace xla_chlo { diff --git a/lib/Dialect/mhlo/IR/dialect_registration.cc b/lib/Dialect/mhlo/IR/dialect_registration.cc index 5e45b51..7c3a8ec 100644 --- a/lib/Dialect/mhlo/IR/dialect_registration.cc +++ b/lib/Dialect/mhlo/IR/dialect_registration.cc @@ -21,4 +21,4 @@ limitations under the License. static mlir::DialectRegistration mhlo_ops; static mlir::DialectRegistration xla_chlo_ops; -static mlir::DialectRegistration xla_lhlo_ops; +static mlir::DialectRegistration lmhlo_ops; diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index 305df4f..0c60c6f 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -46,9 +46,9 @@ limitations under the License. namespace mlir { #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc" -namespace xla_lhlo { +namespace lmhlo { -XlaLhloDialect::XlaLhloDialect(MLIRContext *context) +LmhloDialect::LmhloDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST @@ -138,5 +138,5 @@ void FusionOp::build(OpBuilder &builder, OperationState &result, FusionOp::ensureTerminator(*bodyRegion, builder, result.location); } -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 187b145..fb5dfc4 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -44,7 +44,7 @@ template using BaseOpConversion = BufferAssignmentOpConversionPattern; using StdReturnOpConverter = detail::BufferAssignmentReturnOpConverter; + lmhlo::CopyOp, true>; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, @@ -149,7 +149,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter Value transformed_operand = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); - rewriter.create( + rewriter.create( loc, transformed_operand, resultBuffer, op.broadcast_dimensions()); rewriter.replaceOp(op, {resultBuffer}); @@ -161,7 +161,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter // Inserts dynamic memref to change the layout of the memref to put 0-stride // and size of the target dimension if size-1 dimension expansion is // necessary. 
- xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp( + lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp( mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const { auto loc = op.getLoc(); auto operand_type = operand.getType().cast(); @@ -214,7 +214,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter makeStridedLinearLayoutMap(dynamic_layout, /*offset=*/0, b->getContext())); - auto transformed_operand = b->create( + auto transformed_operand = b->create( loc, type_erased_memref_type, operand, sizes, strides); return transformed_operand; } @@ -239,7 +239,7 @@ struct HloToLhloDynamicReshapeConverter return failure(); } mhlo::DynamicReshapeOp::Adaptor adaptor(operands); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, adaptor.operand(), adaptor.output_shape()); return success(); } @@ -266,8 +266,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { buffer_args.push_back( InsertAlloc(loc, result, this->bufferAssignment, &rewriter)); } - auto new_op = rewriter.create( - loc, llvm::None, buffer_args, op.getAttrs()); + auto new_op = rewriter.create(loc, llvm::None, buffer_args, + op.getAttrs()); // Copy over the operations inside the region. rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end()); @@ -292,7 +292,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { } // Insert terminator at the end. rewriter.setInsertionPointToEnd(&entry_block); - rewriter.create(loc); + rewriter.create(loc); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); @@ -321,8 +321,8 @@ class HloToLhloTensorStoreOpConverter LogicalResult matchAndRewrite( mlir::TensorStoreOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { - rewriter.replaceOpWithNewOp( - op, llvm::None, operands.front(), operands.back()); + rewriter.replaceOpWithNewOp(op, llvm::None, operands.front(), + operands.back()); return success(); } }; @@ -336,7 +336,7 @@ class HloToLhloTensorStoreOpConverter // %arg1: memref<2x2xf32>, // %arg2: memref<2x2xf32>, // %arg3: memref<2x2xf32>) { -// "xla_lhlo.fusion"() ({ +// "lmhlo.fusion"() ({ // %0 = tensor_load %arg1 : memref<2x2xf32> // %1 = tensor_load %arg2 : memref<2x2xf32> // %2 = "mhlo.add"(%0, %1) : @@ -345,7 +345,7 @@ class HloToLhloTensorStoreOpConverter // %4 = "mhlo.multiply"(%2, %3) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // tensor_store %4, %arg3 : memref<2x2xf32> -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.terminator"() : () -> () // }) : () -> () // return // } @@ -355,13 +355,13 @@ class HloToLhloTensorStoreOpConverter // %arg1: memref<2x2xf32>, // %arg2: memref<2x2xf32>, // %arg3: memref<2x2xf32>) { -// "xla_lhlo.fusion"() ( { +// "lmhlo.fusion"() ( { // %0 = alloc() : memref<2x2xf32> -// "xla_lhlo.add"(%arg1, %arg2, %0) : +// "lmhlo.add"(%arg1, %arg2, %0) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -// "xla_lhlo.multiply"(%0, %arg0, %arg3) : +// "lmhlo.multiply"(%0, %arg0, %arg3) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.terminator"() : () -> () // }) : () -> () // return // } @@ -382,13 +382,13 @@ class HloToLhloTensorStoreOpConverter // %arg2: memref<4xf32>) { // %0 = alloc() : memref<4xf32> -// "xla_lhlo.maximum"(%arg0, %arg1, %0) : +// "lmhlo.maximum"(%arg0, %arg1, %0) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () // %1 = alloc() : memref<4xf32> -// "xla_lhlo.add"(%arg0, %0, %1) : +// "lmhlo.add"(%arg0, %0, %1) : // 
(memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () -// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () +// "lmhlo.terminator"() : () -> () // } struct HloLegalizeToLhlo @@ -406,7 +406,7 @@ struct HloLegalizeToLhlo OwningRewritePatternList patterns; auto& context = getContext(); ConversionTarget target(context); - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); target.addIllegalOp(); @@ -441,12 +441,12 @@ struct HloLegalizeToLhlo &converter, &patterns); if (results_escape_function) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp, + mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp, /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment, &converter, &patterns); } else { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp, + mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp, /*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment, &converter, &patterns); } diff --git a/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc b/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc index 4fbd774..3310170 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc @@ -23,7 +23,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // Removes LHLO copy operations that copy from allocated buffers to block @@ -34,7 +34,7 @@ struct LhloCopyRemoval : mlir::PassWrapper> { void runOnOperation() override { llvm::SmallVector eraseList; auto operation = getOperation(); - operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) { + operation->walk([&](mlir::lmhlo::CopyOp copyOp) { // If this region contains more than one block, then ignore this copy // operation. if (copyOp.getParentRegion()->getBlocks().size() > 1) { @@ -101,5 +101,5 @@ std::unique_ptr createLhloCopyRemovalPass() { static PassRegistration copy_removal_pass( "lhlo-copy-removal", "Removes redundant LHLO copy operations"); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index c5b81ec..01aba61 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -27,7 +27,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { using linalg::LinalgOp; @@ -147,5 +147,5 @@ static PassRegistration legalize_pass( "lhlo-fuse-linalg", "Greedily fuse linalg ops obtained after LHLO lowering."); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index f4354d1..7971240 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -28,7 +28,7 @@ limitations under the License. 
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // Builds an affine loop nest iterating from zeros to "upper_bounds" with unit @@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern { auto r = builder.create(loc, rhs, rhs_indices); auto result = rewriter.create(loc, op.output(), result_indices); - Value op_result = xla_lhlo::XlaOpToStdScalarOp::map( + Value op_result = lmhlo::XlaOpToStdScalarOp::map( op, element_type, {l, r, result}, &builder); map_status = success(op_result != nullptr); if (failed(map_status)) return; @@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern { ValueRange induction_vars) { auto l = builder.create(loc, lhs, induction_vars); auto r = builder.create(loc, rhs, induction_vars); - Value op_result = xla_lhlo::XlaOpToStdScalarOp::map( + Value op_result = lmhlo::XlaOpToStdScalarOp::map( op, element_type, {l, r}, &builder); map_status = success(op_result != nullptr); if (failed(map_status)) return; @@ -127,13 +127,13 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off patterns->insert< - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, DotOpConverter>(context); // clang-format on } @@ -157,5 +157,5 @@ std::unique_ptr> createLegalizeToAffinePass() { static PassRegistration legalize_pass( "lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect"); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index bb502ad..dbae1e6 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -38,7 +38,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // A simple translation of LHLO reduce operations to a corresponding gpu @@ -173,7 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper { OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addLegalDialect(); + gpu::GPUDialect, scf::SCFDialect, LmhloDialect>(); target.addIllegalOp(); auto func = getFunction(); patterns.insert(func.getContext()); @@ -192,5 +192,5 @@ std::unique_ptr> createLegalizeToGpuPass() { static PassRegistration legalize_pass( "lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect"); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc index bfd0148..7d83589 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { struct StaticMemRefCastOpConverter @@ -132,5 +132,5 @@ void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, *converter, options); } -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc index 0fa52c0..5313658 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc @@ -23,7 +23,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { class TestLhloToLLVMPass @@ -42,7 +42,7 @@ class TestLhloToLLVMPass ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); - target.addIllegalDialect(); + target.addIllegalDialect(); if (failed(applyFullConversion(m, target, patterns))) { signalPassFailure(); @@ -55,5 +55,5 @@ class TestLhloToLLVMPass static PassRegistration legalize_lhlo_pass( "test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM."); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index cb2451f..b9f4ad4 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -26,7 +26,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // Clones and adapts the code in `lhlo_block` that works on buffers and has a @@ -154,14 +154,14 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, return b->create(loc, lower, upper, step); } -// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp. +// Converts `lmhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops if there are // any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp` // contains the reduction operator. // // Example: // -// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( { +// "lmhlo.reduce"(%buffer, %init_buf, %result) ( { // ^bb0(%lhs: memref, %rhs: memref, %res: memref): // // } ) {dimensions = dense<[1]> : tensor<1xi64>} @@ -187,12 +187,12 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, // } : f32 // scf.yield // } -class ReduceOpConverter : public OpConversionPattern { +class ReduceOpConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::ReduceOp xla_reduce_op, ArrayRef /*args*/, + lmhlo::ReduceOp xla_reduce_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { // TODO(b/137624192) Implement variadic reduce. 
if (xla_reduce_op.out().size() != 1) return failure(); @@ -226,7 +226,7 @@ class ReduceOpConverter : public OpConversionPattern { // scf.yield // } scf::ReduceOp CreateReduceOpInNestedParallelLoops( - xla_lhlo::ReduceOp xla_reduce_op, + lmhlo::ReduceOp xla_reduce_op, ConversionPatternRewriter* rewriter) const { auto loc = xla_reduce_op.getLoc(); DenseSet reducing_dims; @@ -314,7 +314,7 @@ class ReduceOpConverter : public OpConversionPattern { // accumulator = reduction_operator(output[O], value) // output[O] = accumulator // -// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a +// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a // scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops that traverese output // buffer. The inner `ParalleOp` refers to the reduction loops that traverse @@ -325,11 +325,11 @@ class ReduceOpConverter : public OpConversionPattern { // func @reduce_window(%arg: memref<112x112xf32>, // %init: memref, // %result: memref<56x56xf32>) { -// "xla_lhlo.reduce_window"(%arg, %init, %result) ( { +// "lmhlo.reduce_window"(%arg, %init, %result) ( { // ^bb0(%lhs: memref, %rhs: memref, %res: memref): -// "xla_lhlo.maximum"(%lhs, %rhs, %res) +// "lmhlo.maximum"(%lhs, %rhs, %res) // : (memref, memref, memref) -> () -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.terminator"() : () -> () // }) { // padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, // window_dimensions = dense<[3, 3]> : tensor<2xi64>, @@ -359,12 +359,12 @@ class ReduceOpConverter : public OpConversionPattern { // return // } class ReduceWindowOpConverter - : public OpConversionPattern { + : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, + lmhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { scf::ParallelOp output_loop, window_loop; std::tie(output_loop, window_loop) = @@ -383,7 +383,7 @@ class ReduceWindowOpConverter private: std::pair CreateParallelLoopsToTraverseOutputAndWindow( - xla_lhlo::ReduceWindowOp xla_reduce_window_op, + lmhlo::ReduceWindowOp xla_reduce_window_op, ConversionPatternRewriter* rewriter) const { auto loc = xla_reduce_window_op.getLoc(); Value init_value = @@ -415,9 +415,8 @@ class ReduceWindowOpConverter } scf::ReduceOp CreateReduceOpInNestedParallelLoops( - xla_lhlo::ReduceWindowOp xla_reduce_window_op, - scf::ParallelOp output_loop, scf::ParallelOp window_loop, - ConversionPatternRewriter* rewriter) const { + lmhlo::ReduceWindowOp xla_reduce_window_op, scf::ParallelOp output_loop, + scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const { rewriter->setInsertionPointToStart(window_loop.getBody()); auto loc = xla_reduce_window_op.getLoc(); @@ -481,12 +480,12 @@ class ReduceWindowOpConverter // initialized_flag = true // output(selected_index) = scatter(output(selected_index), source(S)) class SelectAndScatterOpConverter - : public OpConversionPattern { + : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef /*args*/, + lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { auto loc = s_and_s_op.getLoc(); InitializeOutput(s_and_s_op, &rewriter); @@ -515,7 +514,7 
@@ class SelectAndScatterOpConverter } private: - void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op, + void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); Value init_value = b->create(loc, s_and_s_op.init_value()); @@ -533,7 +532,7 @@ class SelectAndScatterOpConverter SmallVector window_ivs; scf::ForOp inner_loop; }; - WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op, + WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op, scf::ParallelOp loop_over_src, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); @@ -598,7 +597,7 @@ class SelectAndScatterOpConverter SmallVector ivs_val_flag_; }; - SmallVector SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op, + SmallVector SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op, scf::ParallelOp loop_over_src, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); @@ -636,9 +635,10 @@ class SelectAndScatterOpConverter return window_loops.selected_ivs; } - SmallVector SelectOrInitialize( - xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef operand_ivs, - IterArgs* ivs_val_flag, OpBuilder* b) const { + SmallVector SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op, + ArrayRef operand_ivs, + IterArgs* ivs_val_flag, + OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); Value true_i1 = b->create( loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1)); @@ -707,9 +707,9 @@ struct LhloLegalizeToParallelLoops ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); + scf::SCFDialect, LmhloDialect>(); + target.addIllegalOp(); if (failed(applyPartialConversion(func, target, patterns))) { signalPassFailure(); @@ -727,5 +727,5 @@ static PassRegistration legalize_lhlo_pass( "lhlo-legalize-to-parallel-loops", "Legalize from LHLO dialect to parallel loops."); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc index 9dd69b8..ccecadd 100644 --- a/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc @@ -131,9 +131,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern { loc, opResultTypes, args, args_count, results_count, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { - // TODO(ravishankarm) : For now use the method in xla_lhlo namespace. + // TODO(ravishankarm) : For now use the method in lmhlo namespace. // That method needs to be moved out of there. - Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + Value opResult = lmhlo::XlaOpToStdScalarOp::map( op, bodyResultTypes, llvm::to_vector<2>(args.take_front(args_count)), &rewriter); nestedBuilder.create(loc, opResult); @@ -162,8 +162,8 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { // Create two loads from the input. auto lhs = rewriter.create(loc, lhlo_op.lhs()); auto rhs = rewriter.create(loc, lhlo_op.rhs()); - // TODO(ravishankarm) : Move this method out of xla_lhlo namespace. - Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + // TODO(ravishankarm) : Move this method out of lmhlo namespace. 
+ Value opResult = lmhlo::XlaOpToStdScalarOp::map( lhlo_op, argType.getElementType(), llvm::ArrayRef{lhs, rhs}, &rewriter); rewriter.create(loc, opResult, lhlo_op.out()); @@ -173,21 +173,21 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { }; //===----------------------------------------------------------------------===// -// xla_lhlo.convolution conversion pattern. +// lmhlo.convolution conversion pattern. //===----------------------------------------------------------------------===// -/// Converts xla_lhlo.convolution operation to a linalg.conv op. -struct ConvToLinalgConverter : public OpConversionPattern { +/// Converts lmhlo.convolution operation to a linalg.conv op. +struct ConvToLinalgConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; // This code has been adapted from IREE's // (https://github.com/google/iree/) mhlo -> linalg conversion. LogicalResult matchAndRewrite( - xla_lhlo::ConvOp op, ArrayRef args, + lmhlo::ConvOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { // Check validity of dimension information. - if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers = + if (const lmhlo::ConvDimensionNumbers& dimensionNumbers = op.dimension_numbers()) { const int inputSpatialRank = llvm::size(dimensionNumbers.input_spatial_dimensions()); @@ -388,14 +388,14 @@ class HloBroadcastInDimConverter }; class LhloBroadcastInDimConverter - : public OpConversionPattern { + : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::BroadcastInDimOp op, ArrayRef args, + lmhlo::BroadcastInDimOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { - xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); + lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); auto result_type = operand_adaptor.output().getType().cast(); auto result_shape = result_type.getShape(); @@ -444,9 +444,9 @@ class LhloBroadcastInDimConverter // Inserts 'linalg.reshape' if there is a size-1 dim expansion. 
std::pair> InsertReshapeIfNecessary( - xla_lhlo::BroadcastInDimOp op, ArrayRef args, + lmhlo::BroadcastInDimOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const { - xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); + lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); Value operand = operand_adaptor.operand(); auto operand_type = operand_adaptor.operand().getType().cast(); auto operand_shape = operand_type.getShape(); @@ -512,7 +512,7 @@ class LhloBroadcastInDimConverter return std::make_pair(operand, broadcast_dims); } - SmallVector getIndexingMaps(xla_lhlo::BroadcastInDimOp op, + SmallVector getIndexingMaps(lmhlo::BroadcastInDimOp op, ArrayRef broadcastDims, ArrayRef resultShape, MemRefType operandType, @@ -639,12 +639,12 @@ class ReshapeOpConverter : public OpConversionPattern { } }; -class IotaConverter : public OpConversionPattern { +class IotaConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::IotaOp iotaOp, ArrayRef args, + lmhlo::IotaOp iotaOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto resultMemrefType = iotaOp.getOperand().getType().dyn_cast(); @@ -680,12 +680,12 @@ class IotaConverter : public OpConversionPattern { } }; -class ConstConverter : public OpConversionPattern { +class ConstConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::ConstOp constOp, ArrayRef args, + lmhlo::ConstOp constOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = constOp.getLoc(); auto valueAttr = constOp.value().cast(); @@ -726,12 +726,12 @@ class ReverseConverter } }; -class SliceConverter : public OpConversionPattern { +class SliceConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::SliceOp sliceOp, ArrayRef args, + lmhlo::SliceOp sliceOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = sliceOp.getLoc(); auto argType = @@ -763,50 +763,50 @@ class SliceConverter : public OpConversionPattern { void populateLHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off - patterns->insert, + patterns->insert, ConstConverter, ConvToLinalgConverter, IotaConverter, LhloBroadcastInDimConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, // TODO(ataei): Remove this pattern, CopyOp is folded away. 
- PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - ReshapeOpConverter, - ReverseConverter, - ScalarPointwiseToStandardConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + ReshapeOpConverter, + ReverseConverter, + ScalarPointwiseToStandardConverter, SliceConverter >(context); // clang-format on } // Converts LHLO ops to Linalg generic. -// Sample result for xla_lhlo::AddOp. +// Sample result for lmhlo::AddOp. // -// "xla_lhlo.add"(%arg1, %arg2, %out) : +// "lmhlo.add"(%arg1, %arg2, %out) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // // will be converted to @@ -854,14 +854,14 @@ struct HloLegalizeToLinalg } // namespace -namespace xla_lhlo { +namespace lmhlo { std::unique_ptr> createLegalizeLhloToLinalgPass() { return absl::make_unique(); } static PassRegistration legalize_lhlo_pass( "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect"); -} // namespace xla_lhlo +} // namespace lmhlo namespace mhlo { diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index a555935..aa5d800 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -7,7 +7,7 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_result = "mhlo.exponential"(%tensor_operand) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} + // BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -18,10 +18,10 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> { return %arg0 : tensor<4xf32> } // PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) -// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () +// PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () // PRE-NEXT: return // ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] -// ESC-NOT: "xla_lhlo.copy" +// ESC-NOT: "lmhlo.copy" // ESC-NEXT: return %[[ARG0]] // ----- @@ -38,20 +38,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> // PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) // ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32> // BOTH-NEXT: %[[MAX_RESULT:.*]] 
= alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) +// BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) +// BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) // BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> // BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) +// BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) // BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> -//  BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) +//  BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) // BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) +// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) // BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> -// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () +// PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () // PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> // PRE-NEXT: return // ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32> @@ -67,14 +67,14 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> %sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) + // BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> %tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) + // BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> - // BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) + // BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) tensor_store %tensor_result, %result : memref<2x2xf32> // BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> // BOTH-NEXT: return @@ -88,7 +88,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.copy"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -100,7 +100,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.exponential"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -112,7 +112,7 @@ func @log(%operand: memref<2x2xf32>, %result: 
memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.log"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.log"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -127,7 +127,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + // BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -141,7 +141,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x %tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs) {comparison_direction = "EQ"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> - // BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} + // BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} tensor_store %tensor_result, %result : memref<2x2xi1> return } @@ -154,7 +154,7 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { %tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<10x5xf32> - // BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<10x5xf32> return } @@ -205,12 +205,12 @@ func @dyn_broadcast(%operand: memref) { // BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]] // BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index - // BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast + // BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast // BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]) // BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]] // BOTH-SAME: : memref -> memref - // BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { + // BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { // BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> // BOTH-SAME: } : (memref, memref) -> () @@ -229,7 +229,7 @@ func @complex(%real: memref<2x2xf32>, %tensor_imag = tensor_load %imag : memref<2x2xf32> %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex> - // BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xcomplex> return } @@ -241,7 +241,7 @@ func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> %tensor_result = "mhlo.real"(%tensor_operand) : (tensor<2x2xcomplex>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.real"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -253,7 +253,7 @@ func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> %tensor_result = "mhlo.imag"(%tensor_operand) : (tensor<2x2xcomplex>) -> tensor<2x2xf32> - // 
BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -264,7 +264,7 @@ func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { func @iota(%result: memref<10xi32>) { %tensor_result = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<10xi32> - // BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} + // BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} tensor_store %tensor_result, %result : memref<10xi32> return } @@ -276,7 +276,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.abs"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -288,7 +288,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.ceil"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -300,7 +300,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.convert"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}}) // BOTH-NOT: tensor_store tensor_store %tensor_result, %result : memref<2x2xf32> return @@ -313,7 +313,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.cosine"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -325,7 +325,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.negate"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -337,7 +337,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.rsqrt"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -349,7 +349,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.sign"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -361,7 +361,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.sqrt"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}}) 
tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -373,7 +373,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.tanh"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -386,7 +386,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) + // BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -412,7 +412,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) { // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) - // BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () + // BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () return } @@ -437,7 +437,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) { // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) - // BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () + // BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () return } @@ -448,10 +448,10 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { // PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // BOTH-NEXT: %[[ALLOC:.*]] = alloc -// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () +// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () %dot = "mhlo.dot"(%arg0, %arg0) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> -// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]]) +// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]]) // ESC: return %[[ALLOC]] return %dot : tensor<1024x1024xf32> } @@ -462,7 +462,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> { %c0 = constant 0 : index // BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32> - // BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]]) + // BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]]) // BOTH-SAME: padding = dense<[ // BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64> // BOTH-SAME: rhs_dilation = dense<[1, 2]> diff --git a/tests/lhlo-copy-removal.mlir index 3d3f802..6d7992c 100644 --- a/tests/lhlo-copy-removal.mlir +++ b/tests/lhlo-copy-removal.mlir @@ -3,10 +3,10 @@ // CHECK-LABEL: func @remove_simple func @remove_simple(%arg0: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + 
"lmhlo.terminator"() : () -> () } // ----- @@ -14,9 +14,9 @@ func @remove_simple(%arg0: memref<2x2xf32>) { // CHECK-LABEL: func @remove_without_dealloc func @remove_without_dealloc(%arg0: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // ----- @@ -24,22 +24,22 @@ func @remove_without_dealloc(%arg0: memref<2x2xf32>) { // CHECK-LABEL: func @replace_dependency func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // ----- // CHECK-LABEL: func @keep_copies func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { - // CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () + // CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // ----- @@ -50,14 +50,14 @@ func @must_not_be_removed(%arg0: memref<2x2xf32>, %arg2: memref<2x2xf32>) { // CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32> %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // ----- @@ -67,13 +67,13 @@ func @must_be_removed_first(%arg0: memref<2x2xf32>, 
%arg1: memref<2x2xf32>, %arg2: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // ----- @@ -83,11 +83,11 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>, %arg2: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } diff --git a/tests/lhlo-legalize-select-and-scatter.mlir index 2aa6378..b110d8d 100644 --- a/tests/lhlo-legalize-select-and-scatter.mlir +++ b/tests/lhlo-legalize-select-and-scatter.mlir @@ -10,18 +10,18 @@ func @select_and_scatter(%arg: memref<112x112xf32>, %src: memref<56x56xf32>, %init: memref<f32>, %result: memref<112x112xf32>) { - "xla_lhlo.select_and_scatter"(%arg, %src, %init, %result) ( { + "lmhlo.select_and_scatter"(%arg, %src, %init, %result) ( { // select ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %pred: memref<i1>): - "xla_lhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} : + "lmhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} : (memref<f32>, memref<f32>, memref<i1>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () }, { // scatter ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>): - "xla_lhlo.add"(%lhs, %rhs, %out) : + "lmhlo.add"(%lhs, %rhs, %out) : (memref<f32>, memref<f32>, memref<f32>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () }) { padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, window_dimensions = dense<[3, 3]> : tensor<2xi64>, @@ -29,7 +29,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>, } : (memref<112x112xf32>, memref<56x56xf32>, memref<f32>, memref<112x112xf32>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // CHECK-LABEL: func 
@select_and_scatter( // CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>, @@ -121,7 +121,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32> // Compute PRED. - // CHECK: "xla_lhlo.compare"( + // CHECK: "lmhlo.compare"( // CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]]) // CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1> @@ -182,7 +182,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32> // Compute scatter value. -// CHECK: "xla_lhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) : +// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) : // CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> () // CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32> diff --git a/tests/lhlo-legalize-to-affine.mlir index 1068d1a..8781804 100644 --- a/tests/lhlo-legalize-to-affine.mlir +++ b/tests/lhlo-legalize-to-affine.mlir @@ -14,7 +14,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, // CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> // CHECK: return - "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : + "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> () return } @@ -24,7 +24,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: addf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"} + "lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -32,7 +32,7 @@ func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: addi %{{.*}}, %{{.*}} : i32 - "xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"} + "lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -42,7 +42,7 @@ func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: and %{{.*}}, %{{.*}} : i32 - "xla_lhlo.and"(%lhs, %rhs, %result) {name = "and.1"} + "lmhlo.and"(%lhs, %rhs, %result) {name = "and.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -52,7 +52,7 @@ func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: divf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} + "lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -60,7 +60,7 @@ func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: divi_signed %{{.*}}, %{{.*}} : i32 - "xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} + "lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -71,7 +71,7 @@ func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], 
%[[TWO:.*]] : f32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 - "xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} + "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -81,7 +81,7 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 - "xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} + "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -92,7 +92,7 @@ func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: %[[CHECK:.*]] = cmpf "olt", %[[ONE:.*]], %[[TWO:.*]] : f32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 - "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} + "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -102,7 +102,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 - "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} + "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -112,7 +112,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: mulf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} + "lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -121,7 +121,7 @@ func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: muli %{{.*}}, %{{.*}} : i32 - "xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} + "lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -131,7 +131,7 @@ func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: subf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} + "lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -139,7 +139,7 @@ func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: subi %{{.*}}, %{{.*}} : i32 - "xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} + "lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -158,7 +158,7 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32 // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32> // CHECK: return - "xla_lhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) : (memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> () return } @@ -175,7 +175,7 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32 // 
CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32> // CHECK: return - "xla_lhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) : (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> () return } diff --git a/tests/lhlo-legalize-to-gpu.mlir index e996581..02ad365 100644 --- a/tests/lhlo-legalize-to-gpu.mlir +++ b/tests/lhlo-legalize-to-gpu.mlir @@ -3,11 +3,11 @@ func @reduce(%arg: memref<100x10xf32>, %init: memref<f32>, %result: memref<100xf32>) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref<f32>, memref<f32>, memref<f32>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> () return @@ -25,7 +25,7 @@ func @reduce(%arg: memref<100x10xf32>, // CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { // CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref<f32> // CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref<f32> -// CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32>, memref<f32>, memref<f32>) -> () +// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32>, memref<f32>, memref<f32>) -> () // CHECK: } // CHECK: gpu.terminator // CHECK: } diff --git a/tests/lhlo-legalize-to-linalg.mlir index 8ebfb6b..6981466 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: func @element_wise func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.add"(%lhs, %rhs, %result) + "lmhlo.add"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -19,7 +19,7 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>, %rhs: memref<?x?xf32>, %result: memref<?x?xf32>) { - "xla_lhlo.add"(%lhs, %rhs, %result) + "lmhlo.add"(%lhs, %rhs, %result) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () return } @@ -33,7 +33,7 @@ func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>, // CHECK-LABEL: func @element_wise_scalar func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>, %result: memref<f32>) { - "xla_lhlo.add"(%lhs, %rhs, %result) + "lmhlo.add"(%lhs, %rhs, %result) : (memref<f32>, memref<f32>, memref<f32>) -> () return } @@ -48,7 +48,7 @@ func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>, // CHECK-LABEL: func @minf func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.minimum"(%lhs, %rhs, %result) + "lmhlo.minimum"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -63,7 +63,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @maxi func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.maximum"(%lhs, %rhs, %result) + "lmhlo.maximum"(%lhs, %rhs, %result) : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -78,7 +78,7 @@ func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // CHECK-LABEL: func @and func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.and"(%lhs, %rhs, %result) + "lmhlo.and"(%lhs, %rhs, %result) : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () return } 
@@ -91,7 +91,7 @@ func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.exponential"(%input, %result) + "lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -104,7 +104,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @log func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -116,7 +116,7 @@ func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @copy func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { - "xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () + "lmhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () return } // CHECK: linalg.generic @@ -128,7 +128,7 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { // CHECK-LABEL: func @float_cmp func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) { - "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"} + "lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> () return } @@ -142,7 +142,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @int_cmp func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi1>) { - "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} + "lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> () return } @@ -156,7 +156,7 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // CHECK-LABEL: func @select func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.select"(%pred, %lhs, %rhs, %result) + "lmhlo.select"(%pred, %lhs, %rhs, %result) : (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -170,7 +170,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @iota func @iota(%out: memref<7x10xf32>) { - "xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> () + "lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> () return } // CHECK: linalg.indexed_generic @@ -186,7 +186,7 @@ func @iota(%out: memref<7x10xf32>) { // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) { - "xla_lhlo.broadcast"(%operand, %result) { + "lmhlo.broadcast"(%operand, %result) { broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> } : (memref<f32>, memref<4x2x1xf32>) -> () return @@ -203,7 +203,7 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) { // CHECK-LABEL: func @broadcast func @broadcast(%operand: memref<4x?x16xf32>, %result: memref<4x2x1x4x?x16xf32>) { - "xla_lhlo.broadcast"(%operand, %result) { + "lmhlo.broadcast"(%operand, %result) { broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> } : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> () return @@ -220,7 +220,7 @@ func @broadcast(%operand: memref<4x?x16xf32>, // CHECK-LABEL: func 
@dynamic_broadcast_in_dim func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>, %result: memref<?x?x?x?x?xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64> } : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> () return @@ -237,7 +237,7 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>, // CHECK-LABEL: func @static_broadcast_in_dim_no_expansion func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>, %result: memref<5x10xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[0]> : tensor<1xi64> } : (memref<5xf32>, memref<5x10xf32>) -> () return @@ -255,7 +255,7 @@ func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>, // CHECK-LABEL: func @static_broadcast_in_dim_expansion func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>, %result: memref<5x10x100xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[2, 0]> : tensor<2xi64> } : (memref<1x5xf32>, memref<5x10x100xf32>) -> () return @@ -274,7 +274,7 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>, // CHECK-LABEL: func @static_broadcast_in_dim_scalar func @static_broadcast_in_dim_scalar(%operand: memref<f32>, %result: memref<5x10xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[]> : tensor<0xi64> } : (memref<f32>, memref<5x10xf32>) -> () return @@ -291,7 +291,7 @@ func @static_broadcast_in_dim_scalar(%operand: memref<f32>, // CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, %result: memref<1x5xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[0]> : tensor<1xi64> } : (memref<1xf32>, memref<1x5xf32>) -> () return @@ -307,7 +307,7 @@ func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, // CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>, %result: memref<5x5xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (memref<1xf32>, memref<5x5xf32>) -> () return @@ -323,7 +323,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>, // CHECK-LABEL: func @constant func @constant(%value: memref<i32>) { - "xla_lhlo.constant"(%value) { + "lmhlo.constant"(%value) { value = dense<10> : tensor<i32> } : (memref<i32>) -> () return } @@ -335,7 +335,7 @@ func @constant(%value: memref<i32>) { // CHECK-LABEL: func @absf func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -348,7 +348,7 @@ func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @absi func @absi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -364,7 +364,7 @@ func @absi(%input: memref<2x2xi32>, // CHECK-LABEL: func @ceil func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - 
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -376,7 +376,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_i32_to_f32 func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -389,7 +389,7 @@ func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_i16_to_i32 func @convert_i16_to_i32(%input: memref<2x2xi16>, %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -401,7 +401,7 @@ func @convert_i16_to_i32(%input: memref<2x2xi16>, // CHECK-LABEL: func @convert_i32_to_i16 func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> () return } // CHECK: linalg.generic @@ -413,7 +413,7 @@ func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) { // CHECK-LABEL: func @convert_f32_to_f64 func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> () return } // CHECK: linalg.generic @@ -425,7 +425,7 @@ func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) { // CHECK-LABEL: func @convert_f64_to_f32 func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -437,7 +437,7 @@ func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_i32_to_i32 func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -448,7 +448,7 @@ func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @convert_f32_to_f32 func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -459,7 +459,7 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_f32_to_i32 func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) + "lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xi32>) -> () return } @@ -472,7 +472,7 @@ func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @cos func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - 
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -485,7 +485,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @sin func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sine"(%input, %result) + "lmhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -498,7 +498,7 @@ func @sin(%input: memref<2x2xf32>, // CHECK-LABEL: func @negf func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -510,7 +510,7 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @negi func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -524,7 +524,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @rem func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.remainder"(%lhs, %rhs, %result) + "lmhlo.remainder"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -537,7 +537,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @rsqrt func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -549,7 +549,7 @@ func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @sign func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -562,7 +562,7 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @sqrt func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -574,7 +574,7 @@ func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @tanh func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -588,7 +588,7 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @complex(%real: memref<2x2xf32>, %imag: memref<2x2xf32>, %cplx: memref<2x2xcomplex>) { - "xla_lhlo.complex"(%real, %imag, %cplx) + "lmhlo.complex"(%real, %imag, %cplx) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex>) -> () return } @@ -602,7 +602,7 @@ func @complex(%real: memref<2x2xf32>, // CHECK-LABEL: func @real func @real(%cplx: memref<2x2xcomplex>, %real: memref<2x2xf32>) { - "xla_lhlo.real"(%cplx, %real) + "lmhlo.real"(%cplx, %real) : (memref<2x2xcomplex>, 
memref<2x2xf32>) -> () return } @@ -616,7 +616,7 @@ func @real(%cplx: memref<2x2xcomplex<f32>>, // CHECK-LABEL: func @imag func @imag(%cplx: memref<2x2xcomplex<f32>>, %imag: memref<2x2xf32>) { - "xla_lhlo.imag"(%cplx, %imag) + "lmhlo.imag"(%cplx, %imag) : (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> () return } @@ -629,7 +629,7 @@ func @imag(%cplx: memref<2x2xcomplex<f32>>, // CHECK: func @slice(%[[IN:.*]]: memref<?x?xf32>, %[[OUT:.*]]: memref<?x?xf32>) func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) { - "xla_lhlo.slice"(%operand, %result) { + "lmhlo.slice"(%operand, %result) { start_indices = dense<[0,1]> : tensor<2xi64>, limit_indices = dense<[2,3]> : tensor<2xi64>, strides = dense<[1,1]> : tensor<2xi64> @@ -653,7 +653,7 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) { // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-LABEL: func @reshape_3D_2D func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) + "lmhlo.reshape"(%arg0, %arg1) : (memref<12x1x42xi32>, memref<12x42xi32>) -> () return } @@ -666,7 +666,7 @@ func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> // CHECK-LABEL: func @reshape_4D_2D func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) + "lmhlo.reshape"(%arg0, %arg1) : (memref<12x42x1x1xi32>, memref<12x42xi32>) -> () return } @@ -679,7 +679,7 @@ func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // CHECK-LABEL: func @reshape_2D_4D func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) + "lmhlo.reshape"(%arg0, %arg1) : (memref<12x42xi32>, memref<12x1x42x1xi32>) -> () return } @@ -692,7 +692,7 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reverse func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) { - "xla_lhlo.reverse"(%arg0, %arg1) { + "lmhlo.reverse"(%arg0, %arg1) { dimensions = dense<1> : tensor<1xi64> } : (memref<2x3xf32>, memref<2x3xf32>) -> () return @@ -710,15 +710,15 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m // CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64> // CHECK-SAME: strides = [2, 1]} // With all attributes explicitly specified. 
- "xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () + "lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () // Dilation left unspecified, sets default dilation since linalg expects it. // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}}) // CHECK-SAME: dilations = [1, 1] // Padding is not set if it's zero. // CHECK-NOT: padding - "xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () + "lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () - "xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> () + "lmhlo.terminator"() : () -> () } diff --git a/tests/lhlo-legalize-to-llvm.mlir b/tests/lhlo-legalize-to-llvm.mlir index a9759c0..a25a508 100644 --- a/tests/lhlo-legalize-to-llvm.mlir +++ b/tests/lhlo-legalize-to-llvm.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func @static_memref_cast func @static_memref_cast(%buf : memref<10x1x5xf32>) { - %0 = 
xla_lhlo.static_memref_cast %buf + %0 = lmhlo.static_memref_cast %buf : memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]> return } @@ -38,7 +38,7 @@ func @dynamic_memref_cast(%buf : memref<?x?xf32>) { %size_Y = constant 50 : index %stride_X = constant 1 : index %stride_Y = constant 0 : index - %0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y] + %0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y] : memref -> memref return } diff --git a/tests/lhlo-legalize-to-parallel-loops.mlir index a3d76ef..1530f59 100644 --- a/tests/lhlo-legalize-to-parallel-loops.mlir +++ b/tests/lhlo-legalize-to-parallel-loops.mlir @@ -3,11 +3,11 @@ func @reduce(%arg: memref<100x10x5xf32>, %init: memref<f32>, %result: memref<100x5xf32>) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref<f32>, memref<f32>, memref<f32>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> () return @@ -35,7 +35,7 @@ func @reduce(%arg: memref<100x10x5xf32>, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32> // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32> // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32> // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } @@ -49,11 +49,11 @@ func @reduce(%arg: memref<100x10x5xf32>, func @reduce_no_outer_loop(%arg: memref<100xf32>, %init: memref<f32>, %result: memref<1xf32>) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref<f32>, memref<f32>, memref<f32>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<100xf32>, memref<f32>, memref<1xf32>) -> () return @@ -76,7 +76,7 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32> // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32> // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32> // CHECK: scf.reduce.return [[ACC_RESULT]] // CHECK: } @@ -88,11 +88,11 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, func @dynamic_reduce(%arg: memref<?x?x?xf32>, %init: memref<f32>, %result: memref<?x?xf32>) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref<f32>, memref<f32>, memref<f32>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?x?xf32>, memref<f32>, memref<?x?xf32>) -> () return @@ -121,7 +121,7 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32> // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32> // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: 
"lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } @@ -135,11 +135,11 @@ func @dynamic_reduce(%arg: memref, func @reduce_window(%arg: memref<112x112xf32>, %init: memref, %result: memref<56x56xf32>) { - "xla_lhlo.reduce_window"(%arg, %init, %result) ( { + "lmhlo.reduce_window"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.maximum"(%lhs, %rhs, %res) + "lmhlo.maximum"(%lhs, %rhs, %res) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () }) { padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, window_dimensions = dense<[3, 3]> : tensor<2xi64>, @@ -189,7 +189,7 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir index e793e2a..30ff965 100644 --- a/tests/lhlo_ops.mlir +++ b/tests/lhlo_ops.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: func @ceil func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -12,7 +12,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point values}} - "xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -20,7 +20,7 @@ func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @cos func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -28,7 +28,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @cos func @cos(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { - "xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () return } @@ -36,7 +36,7 @@ func @cos(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -44,7 +44,7 @@ func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @sin func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -52,7 +52,7 @@ func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: 
func @sin func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) { - "xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> () + "lmhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> () return } @@ -60,7 +60,7 @@ func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) { func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -68,7 +68,7 @@ func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @add_memrefs func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -76,7 +76,7 @@ func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1 // CHECK-LABEL: func @abs_memref func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -84,7 +84,7 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @convert_memref func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () { - "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> () + "lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> () return } @@ -92,7 +92,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () { func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () { // expected-error@+1{{requires the same shape for all operands}} - "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> () + "lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> () return } @@ -100,7 +100,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () { // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -108,7 +108,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) { - "xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> () + "lmhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> () return } @@ -116,7 +116,7 @@ func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) { func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -124,7 +124,7 @@ func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @log_memref func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.log"(%in, %out) : 
(memref<10xf32>, memref<10xf32>) -> () return } @@ -132,7 +132,7 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @log_memref func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () { - "xla_lhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () + "lmhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () return } @@ -140,7 +140,7 @@ func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -148,7 +148,7 @@ func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @neg_memref func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -156,7 +156,7 @@ func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @rsqrt_memref func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -164,7 +164,7 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @rsqrt_memref func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () { - "xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () + "lmhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () return } @@ -172,7 +172,7 @@ func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -180,7 +180,7 @@ func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @sqrt_memref func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -188,7 +188,7 @@ func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @sqrt_memref func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () { - "xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () + "lmhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () return } @@ -196,7 +196,7 @@ func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) - func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -204,7 +204,7 @@ func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @sign_memref func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.sign"(%in, %out) : (memref<10xf32>, 
memref<10xf32>) -> () + "lmhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -212,7 +212,7 @@ func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @tanh_memref func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -220,7 +220,7 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @tanh_memref func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () { - "xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () + "lmhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () return } @@ -228,15 +228,15 @@ func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) - func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } // ----- func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () { - // expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}} - "xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> () + // expected-error@+1{{'lmhlo.tanh' op requires all operands to have the same type}} + "lmhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> () return } @@ -244,7 +244,7 @@ func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () { // CHECK-LABEL: func @add_memref func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -252,7 +252,7 @@ func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @div_memref func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -260,7 +260,7 @@ func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @max_memref func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -268,7 +268,7 @@ func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @min_memref func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -276,7 +276,7 @@ func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @mul_memref func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, 
memref<10xf32>, memref<10xf32>) -> () + "lmhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -284,7 +284,7 @@ func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @sub_memref func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -292,7 +292,7 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @and_memref func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { - "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () + "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () return } @@ -300,7 +300,7 @@ func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32 // CHECK-LABEL: func @and_memref func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { - "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () return } @@ -308,7 +308,7 @@ func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -316,7 +316,7 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @or_memref func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { - "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () + "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () return } @@ -324,7 +324,7 @@ func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32> // CHECK-LABEL: func @or_memref func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { - "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () return } @@ -332,7 +332,7 @@ func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) - func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -340,7 +340,7 @@ func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32> // CHECK-LABEL: func @xor_memref func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { - "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, 
memref<10xi32>, memref<10xi32>) -> () + "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () return } @@ -348,7 +348,7 @@ func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32 // CHECK-LABEL: func @xor_memref func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { - "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () return } @@ -356,7 +356,7 @@ func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -364,7 +364,7 @@ func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @broadcast_in_dim_memref func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () { - "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> () + "lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> () return } @@ -372,7 +372,7 @@ func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) - // CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref func @broadcast_in_dim_zero_rank_memref(%arg0: memref, %out: memref<1x2x3xi32>) -> () { - "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref, memref<1x2x3xi32>) -> () + "lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref, memref<1x2x3xi32>) -> () return } @@ -381,10 +381,10 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref, %out: memref<1x2x3xi // CHECK-LABEL: func @reduce_memref func @reduce_memref(%input: memref<10xf32>, %init: memref, %out: memref<1xf32>) -> () { - "xla_lhlo.reduce"(%input, %init, %out) ( { + "lmhlo.reduce"(%input, %init, %out) ( { ^bb0(%arg1: memref, %arg2: memref, %result: memref): - "xla_lhlo.add"(%arg1, %arg2, %result) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.add"(%arg1, %arg2, %result) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref, memref<1xf32>) -> () return } @@ -393,14 +393,14 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref, %out: memref<1xf // CHECK-LABEL: func @fusion_memref func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.fusion"() ( { + "lmhlo.fusion"() ( { %0 = tensor_load %input1 : memref<10xf32> %1 = tensor_load %input2 : memref<10xf32> %2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> %3 = tensor_load %input3 : memref<10xf32> %4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> tensor_store %4, %out : memref<10xf32> - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) : () -> () 
return } @@ -409,18 +409,18 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m // CHECK-LABEL: func @case_memref func @case_memref(%index: memref, %operand_1: memref, %operand_2: memref, %operand_3: memref, %out: memref) -> () { - "xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { + "lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { ^bb0(%arg0: memref): - "xla_lhlo.negate"(%arg0, %out) : (memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.negate"(%arg0, %out) : (memref, memref) -> () + "lmhlo.terminator"() : () -> () }, { ^bb0(%arg0: memref): - "xla_lhlo.copy"(%arg0, %out) : (memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.copy"(%arg0, %out) : (memref, memref) -> () + "lmhlo.terminator"() : () -> () }, { ^bb0(%arg0: memref): - "xla_lhlo.add"(%arg0, %arg0, %out) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.add"(%arg0, %arg0, %out) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () } ) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>} : (memref, memref, memref, memref, memref) -> () @@ -430,7 +430,7 @@ func @case_memref(%index: memref, %operand_1: memref, %operand_2: memr // ----- func @static_memref_cast(%in: memref<10x1xf32>) { - %out = xla_lhlo.static_memref_cast %in + %out = lmhlo.static_memref_cast %in : memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]> return } @@ -440,7 +440,7 @@ func @static_memref_cast(%in: memref<10x1xf32>) { func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) { // expected-error @+1 {{operand must have static shape}} - %out = xla_lhlo.static_memref_cast %in + %out = lmhlo.static_memref_cast %in : memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]> return } @@ -449,7 +449,7 @@ func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) { func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) { // expected-error @+1 {{result must have static shape}} - %out = xla_lhlo.static_memref_cast %in + %out = lmhlo.static_memref_cast %in : memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]> return } @@ -459,7 +459,7 @@ func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) { func @dynamic_memref_cast(%in: memref) { %size = constant 10 : index %step = constant 1 : index - %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] + %out = lmhlo.dynamic_memref_cast %in(%size)[%step] : memref -> memref return } @@ -471,7 +471,7 @@ func @dynamic_memref_cast_incompatible_result_type(%in: memref) { // expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}} %size = constant 10 : index %step = constant 1 : index - %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] + %out = lmhlo.dynamic_memref_cast %in(%size)[%step] : memref -> memref return } @@ -483,19 +483,19 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>, // CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>, // CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref - // CHECK-NEXT: [[DYN_VEC:%.*]] = xla_lhlo.reshape_memref_cast [[UNRANKED]] + // CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]] // CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref - %dyn_vec = xla_lhlo.reshape_memref_cast %unranked(%shape1) + %dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1) : (memref<*xf32>, memref<1xi32>) -> memref - // CHECK-NEXT: [[DYN_MAT:%.*]] = xla_lhlo.reshape_memref_cast [[DYN_VEC]] 
+ // CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]] // CHECK-SAME: : (memref, memref<2xi32>) -> memref - %dyn_mat = xla_lhlo.reshape_memref_cast %dyn_vec(%shape2) + %dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2) : (memref, memref<2xi32>) -> memref - // CHECK-NEXT: {{%.*}} = xla_lhlo.reshape_memref_cast [[DYN_MAT]] + // CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]] // CHECK-SAME: : (memref, memref) -> memref<*xf32> - %new_unranked = xla_lhlo.reshape_memref_cast %dyn_mat(%shape3) + %new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3) : (memref, memref) -> memref<*xf32> return } @@ -505,7 +505,7 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>, func @reshape_memref_cast_element_type_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{element types of source and destination memref types should be the same}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<*xf32>, memref<1xi32>) -> memref } @@ -514,7 +514,7 @@ func @reshape_memref_cast_element_type_mismatch( func @reshape_memref_cast_dst_ranked_shape_unranked( %buf: memref<*xf32>, %shape: memref) { // expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<*xf32>, memref) -> memref return } @@ -524,7 +524,7 @@ func @reshape_memref_cast_dst_ranked_shape_unranked( func @reshape_memref_cast_dst_shape_rank_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{length of shape operand differs from the result's memref rank}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<*xf32>, memref<1xi32>) -> memref return } @@ -535,7 +535,7 @@ func @reshape_memref_cast_affine_map_is_not_identity( %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>, %shape: memref<1xi32>) { // expected-error @+1 {{operand memref type should have identity affine map}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>) -> memref<8xf32> return @@ -545,7 +545,7 @@ func @reshape_memref_cast_affine_map_is_not_identity( // CHECK-LABEL: func @atan2_memrefs func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -553,7 +553,7 @@ func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref // CHECK-LABEL: func @atan2_memrefs func @atan2_memrefs(%arg0: memref<1xcomplex>, %arg1: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { - "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>, memref<1xcomplex>) -> () + "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>, memref<1xcomplex>) -> () return } @@ -561,7 +561,7 @@ func @atan2_memrefs(%arg0: memref<1xcomplex>, %arg1: memref<1xcomplex> func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, 
memref<1xi32>, memref<1xi32>) -> () return } @@ -569,7 +569,7 @@ func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref // CHECK-LABEL: func @bitcast_convert_memrefs func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> () + "lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> () return } @@ -577,7 +577,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () { // expected-error@+1{{requires the same shape for all operands}} - "xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> () + "lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> () return } @@ -585,7 +585,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> // CHECK-LABEL: func @clz_memrefs func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -593,7 +593,7 @@ func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @expm1_memrefs func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -601,7 +601,7 @@ func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @expm1_memrefs func @expm1_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { - "xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () + "lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () return } @@ -609,7 +609,7 @@ func @expm1_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -617,7 +617,7 @@ func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // expected-error@+1{{must be memref of floating-point values}} - "xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -625,7 +625,7 @@ func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @imag_memrefs func @imag_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () + "lmhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () return } @@ -633,7 +633,7 @@ func @imag_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error@+1{{must be memref of complex-type values}} - "xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -641,7 +641,7 @@ func @imag_memrefs(%arg0: 
memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @real_memrefs func @real_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () + "lmhlo.real"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () return } @@ -649,7 +649,7 @@ func @real_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error@+1{{must be memref of complex-type values}} - "xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -657,7 +657,7 @@ func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @is_finite_memrefs func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () { - "xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> () + "lmhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> () return } @@ -665,7 +665,7 @@ func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () { // CHECK-LABEL: func @log1p_memrefs func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -673,7 +673,7 @@ func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @log1p_memrefs func @log1p_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { - "xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () + "lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () return } @@ -681,7 +681,7 @@ func @log1p_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -689,7 +689,7 @@ func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @not_memrefs func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -697,7 +697,7 @@ func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @not_memrefs func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () { - "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> () + "lmhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> () return } @@ -705,7 +705,7 @@ func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () { func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -713,7 +713,7 @@ func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @popcnt_memrefs func @popcnt_memrefs(%arg0: 
memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -721,7 +721,7 @@ func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -729,7 +729,7 @@ func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @reduce_precision_memrefs func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> () return } @@ -737,7 +737,7 @@ func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> // CHECK-LABEL: func @round_memrefs func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -745,7 +745,7 @@ func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // expected-error@+1{{must be memref of floating-point values}} - "xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -753,7 +753,7 @@ func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @shift_left_memrefs func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -761,7 +761,7 @@ func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: m func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -769,7 +769,7 @@ func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: m // CHECK-LABEL: func @shift_right_arithmetic_memrefs func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -777,7 +777,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, 
func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -785,7 +785,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, // CHECK-LABEL: func @shift_right_logical_memrefs func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -793,7 +793,7 @@ func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %a func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -801,14 +801,14 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a // CHECK-LABEL: func @all_reduce_memrefs func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () { - "xla_lhlo.all_reduce"(%arg0, %arg_out) ({ + "lmhlo.all_reduce"(%arg0, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): %max = mhlo.maximum %lhs, %rhs : tensor "mhlo.return"(%max) : (tensor) -> () }) { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> () - "xla_lhlo.all_reduce"(%arg0, %arg_out) ({ + "lmhlo.all_reduce"(%arg0, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): %max = mhlo.maximum %lhs, %rhs : tensor "mhlo.return"(%max) : (tensor) -> () @@ -826,11 +826,11 @@ func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () // CHECK-LABEL: func @collective_permute_memrefs func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () { - "xla_lhlo.collective_permute"(%arg0, %arg_out) { + "lmhlo.collective_permute"(%arg0, %arg_out) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> } : (memref<128x32xf32>, memref<128x32xf32>) -> () - "xla_lhlo.collective_permute"(%arg0, %arg_out) { + "lmhlo.collective_permute"(%arg0, %arg_out) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, channel_id = { handle = 5 : i64, type = 2 : i64 } } : (memref<128x32xf32>, memref<128x32xf32>) -> () @@ -841,7 +841,7 @@ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128 // CHECK-LABEL: func @fft_memrefs func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex>) -> () { - "xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex>) -> () + "lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex>) -> () return } @@ -852,7 +852,7 @@ func @batch_norm_grad_memrefs(%arg0: 
memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>, %grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>, %grad_offset: memref<8xf32>) -> () { - "xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + "lmhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () return @@ -863,7 +863,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, // CHECK-LABEL: func @batch_norm_inference_memrefs func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () { - "xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + "lmhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> () return } @@ -874,7 +874,7 @@ func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, %output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>, %batch_var: memref<8xf32>) -> () { - "xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + "lmhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () return } @@ -883,8 +883,8 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3 // CHECK-LABEL: func @cholesky_memrefs func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () { - "xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () - "xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () + "lmhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () + "lmhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () return } @@ -892,7 +892,7 @@ func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291x // CHECK-LABEL: func @infeed_memrefs func @infeed_memrefs(%arg_out: memref<3xf32>) -> () { - "xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> () + "lmhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> () return } @@ -900,7 +900,7 @@ func @infeed_memrefs(%arg_out: memref<3xf32>) -> () { // CHECK-LABEL: func @outfeed_memrefs func @outfeed_memrefs(%arg0: memref<3xf32>) -> () { - "xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> () + "lmhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> () return } @@ -908,7 +908,7 @@ func @outfeed_memrefs(%arg0: memref<3xf32>) -> () { // CHECK-LABEL: func @replica_id_memrefs func @replica_id_memrefs(%arg_out: memref) 
-> () { - "xla_lhlo.replica_id"(%arg_out) : (memref) -> () + "lmhlo.replica_id"(%arg_out) : (memref) -> () return } @@ -916,7 +916,7 @@ func @replica_id_memrefs(%arg_out: memref) -> () { // CHECK-LABEL: func @triangular_solve_memrefs func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () { - "xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} + "lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> () return } @@ -925,9 +925,9 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, % // CHECK-LABEL: func @while_memrefs func @while_memrefs(%arg0: memref, %arg_out: memref) -> () { - "xla_lhlo.while"(%arg0, %arg_out) ( - { ^bb0(%arg: memref, %cond: memref): "xla_lhlo.terminator"() : () -> () }, - { ^bb0(%arg: memref, %body_out: memref): "xla_lhlo.terminator"() : () -> () } + "lmhlo.while"(%arg0, %arg_out) ( + { ^bb0(%arg: memref, %cond: memref): "lmhlo.terminator"() : () -> () }, + { ^bb0(%arg: memref, %body_out: memref): "lmhlo.terminator"() : () -> () } ) : (memref, memref) -> () return } @@ -936,9 +936,9 @@ func @while_memrefs(%arg0: memref, %arg_out: memref) -> () { // CHECK-LABEL: func @while_memrefs func @while_memrefs(%arg0: memref, %arg1: memref<5xf32>, %arg0_out: memref, %arg1_out: memref<5xf32>) -> () { - "xla_lhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) ( - { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %cond: memref): "xla_lhlo.terminator"() : () -> () }, - { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %body_out0: memref, %body_out1: memref<5xf32>): "xla_lhlo.terminator"() : () -> () } + "lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) ( + { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %cond: memref): "lmhlo.terminator"() : () -> () }, + { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %body_out0: memref, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () } ) : (memref, memref<5xf32>, memref, memref<5xf32>) -> () return } @@ -947,7 +947,7 @@ func @while_memrefs(%arg0: memref, %arg1: memref<5xf32>, %arg0_out: memref< // CHECK-LABEL: func @bitcast_memrefs func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () { - "xla_lhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> () + "lmhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> () return } @@ -956,7 +956,7 @@ func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () { // CHECK-LABEL: func @scatter_memrefs func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>, %updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () { - "xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({ + "lmhlo.scatter" (%input, %indices, %updates, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): // no predecessors %add = mhlo.add %lhs, %rhs : tensor "mhlo.return"(%add) : (tensor) -> () @@ -977,7 +977,7 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32 // CHECK-LABEL: func @map_memrefs func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () { - "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ + "lmhlo.map"(%arg0, %arg1, %arg_out) ({ ^bb0(%a: tensor, %b: tensor): %c = mhlo.add %a, %b : tensor "mhlo.return"(%c) : (tensor) -> () @@ -989,7 +989,7 @@ func @map_memrefs(%arg0: memref<20xf32>, 
%arg1: memref<20xf32>, %arg_out: memref func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () { // expected-error@+1{{requires the same shape for all operands}} - "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ + "lmhlo.map"(%arg0, %arg1, %arg_out) ({ ^bb0(%a: tensor, %b: tensor): %c = mhlo.add %a, %b : tensor "mhlo.return"(%c) : (tensor) -> () @@ -1001,7 +1001,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref // CHECK-LABEL: func @rng_get_and_update_state_memrefs func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () { - "xla_lhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> () + "lmhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> () return } @@ -1010,7 +1010,7 @@ func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () { // CHECK-LABEL: func @sort_memrefs func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { - "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () @@ -1023,7 +1023,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, // CHECK-LABEL: func @sort_memrefs func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { - "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () @@ -1036,7 +1036,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, // CHECK-LABEL: func @sort_memrefs func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { - "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> ()