Rename xla_lhlo dialect into lmhlo

Following on the plan of isolating the compiler/mlir/hlo directory.
Another xla_lhlo dialect will be created under compiler/mlir/xla/ later.

PiperOrigin-RevId: 320210326
This commit is contained in:
Mehdi Amini 2020-07-08 17:05:32 +00:00 committed by Mehdi Amini
parent b076e018a8
commit 7c4a5d62b5
26 changed files with 566 additions and 550 deletions

View File

@ -35,18 +35,18 @@ class OpBuilder;
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc"
namespace xla_lhlo {
namespace lmhlo {
class XlaLhloDialect : public Dialect {
class LmhloDialect : public Dialect {
public:
explicit XlaLhloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "xla_lhlo"; }
explicit LmhloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "lmhlo"; }
};
#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
} // namespace xla_lhlo
} // namespace lmhlo
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_

View File

@ -38,8 +38,8 @@ include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInte
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
def LHLO_Dialect : Dialect {
let name = "xla_lhlo";
let cppNamespace = "xla_lhlo";
let name = "lmhlo";
let cppNamespace = "lmhlo";
}
//===----------------------------------------------------------------------===//
@ -253,7 +253,7 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example,
// A tuple-like pattern match syntax could work:
// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
// ...
// }, {
// ...
@ -337,7 +337,7 @@ def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
Example:
```mlir
%buf_transformed =
xla_lhlo.static_memref_cast %buf
lmhlo.static_memref_cast %buf
: memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
// The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
@ -379,7 +379,7 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
Example:
```mlir
%buf_transformed =
xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
// The result of the op is a type-erased memref with `[%size_X, %size_Y]`
// shape and `[%step_X, %step_Y]` strides. The offset will be inherited

View File

@ -34,7 +34,7 @@ using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
#define MAP_HLO_TO_LHLO(OpName) \
template <> \
struct HloToLhloOpImpl<mhlo::OpName> { \
using Type = xla_lhlo::OpName; \
using Type = lmhlo::OpName; \
}
MAP_HLO_TO_LHLO(AbsOp);

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace impl {
// A struct to map LhloBinaryOpTy type to the corresponding floating-point and
@ -33,32 +33,32 @@ template <typename LhloBinaryOpTy>
struct LhloToScalarOp;
template <>
struct LhloToScalarOp<xla_lhlo::AddOp> {
struct LhloToScalarOp<lmhlo::AddOp> {
using FOp = ::mlir::AddFOp;
using IOp = ::mlir::AddIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::CompareOp> {
struct LhloToScalarOp<lmhlo::CompareOp> {
using FOp = ::mlir::CmpFOp;
using IOp = ::mlir::CmpIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::DivOp> {
struct LhloToScalarOp<lmhlo::DivOp> {
using FOp = ::mlir::DivFOp;
using IOp = ::mlir::SignedDivIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::MulOp> {
struct LhloToScalarOp<lmhlo::MulOp> {
using FOp = ::mlir::MulFOp;
using IOp = ::mlir::MulIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::RemOp> {
struct LhloToScalarOp<lmhlo::RemOp> {
using FOp = ::mlir::RemFOp;
using IOp = ::mlir::SignedRemIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::SubOp> {
struct LhloToScalarOp<lmhlo::SubOp> {
using FOp = ::mlir::SubFOp;
using IOp = ::mlir::SubIOp;
};
@ -116,8 +116,9 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) {
@ -125,7 +126,7 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
loc, result_types, args, b);
}
if (element_type.isa<IntegerType>()) {
// xla_lhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
// lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>();
@ -133,15 +134,16 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
lhs, zero_intval);
auto neg_val = b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
loc, result_types, args, b);
@ -205,30 +207,33 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc,
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return args.front();
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args,
@ -236,21 +241,23 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RealOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ImagOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type sourceType = args.front().getType();
@ -288,8 +295,9 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
// Dot Op converter from lhlo to affine only accepts float and integer types.
const auto& lhs = args[0];
@ -312,16 +320,18 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SinOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
loc, result_types, args, b);
@ -361,38 +371,40 @@ struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
};
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
result_types, args,
b);
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types,
args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
result_types, args,
b);
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types,
args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) {
@ -400,27 +412,28 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
loc, result_types, args, b);
}
if (element_type.isa<IntegerType>()) {
// xla_lhlo.neg(x, result) -> result = sub(0, x)
// lmhlo.neg(x, result) -> result = sub(0, x)
Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>();
auto zero_intval =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
return b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RsqrtOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
@ -428,8 +441,9 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) {
@ -442,16 +456,18 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SqrtOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
loc, result_types, args, b);
@ -460,10 +476,10 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
} // namespace impl
struct XlaOpToStdScalarOp {
// Implementation for LHLO ops except xla_lhlo::CompareOp.
// Implementation for LHLO ops except lmhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
@ -475,7 +491,7 @@ struct XlaOpToStdScalarOp {
// Implementation for HLO ops except mhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
@ -483,13 +499,13 @@ struct XlaOpToStdScalarOp {
args, b);
}
// Implementation for xla_lhlo::CompareOp.
// Implementation for lmhlo::CompareOp.
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
LhloOpTy, xla_lhlo::CompareOp>::value>>
static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
LhloOpTy, lmhlo::CompareOp>::value>>
static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
@ -500,12 +516,12 @@ struct XlaOpToStdScalarOp {
static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
};
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_

View File

@ -60,7 +60,7 @@ std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
} // namespace mhlo
namespace xla_lhlo {
namespace lmhlo {
// Lowers from LHLO dialect to Affine dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass();
@ -92,7 +92,7 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();
// Lowers from LHLO dialect to parallel loops.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
} // namespace xla_lhlo
} // namespace lmhlo
namespace xla {

View File

@ -75,14 +75,14 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
} // namespace mhlo
namespace xla_lhlo {
namespace lmhlo {
/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
LLVMTypeConverter *converter,
OwningRewritePatternList *patterns);
} // namespace xla_lhlo
} // namespace lmhlo
namespace xla_chlo {

View File

@ -21,4 +21,4 @@ limitations under the License.
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
xla_chlo_ops;
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;
static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops;

View File

@ -46,9 +46,9 @@ limitations under the License.
namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc"
namespace xla_lhlo {
namespace lmhlo {
XlaLhloDialect::XlaLhloDialect(MLIRContext *context)
LmhloDialect::LmhloDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
@ -138,5 +138,5 @@ void FusionOp::build(OpBuilder &builder, OperationState &result,
FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
}
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -44,7 +44,7 @@ template <typename T>
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
using StdReturnOpConverter =
detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
xla_lhlo::CopyOp, true>;
lmhlo::CopyOp, true>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand,
@ -149,7 +149,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
Value transformed_operand =
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
rewriter.create<xla_lhlo::BroadcastInDimOp>(
rewriter.create<lmhlo::BroadcastInDimOp>(
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
rewriter.replaceOp(op, {resultBuffer});
@ -161,7 +161,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
// Inserts dynamic memref to change the layout of the memref to put 0-stride
// and size of the target dimension if size-1 dimension expansion is
// necessary.
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
@ -214,7 +214,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
makeStridedLinearLayoutMap(dynamic_layout,
/*offset=*/0, b->getContext()));
auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>(
auto transformed_operand = b->create<lmhlo::DynamicMemRefCastOp>(
loc, type_erased_memref_type, operand, sizes, strides);
return transformed_operand;
}
@ -239,7 +239,7 @@ struct HloToLhloDynamicReshapeConverter
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<xla_lhlo::ReshapeMemRefCastOp>(
rewriter.replaceOpWithNewOp<lmhlo::ReshapeMemRefCastOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}
@ -266,8 +266,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
buffer_args.push_back(
InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
}
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
loc, llvm::None, buffer_args, op.getAttrs());
auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
op.getAttrs());
// Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
@ -292,7 +292,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
}
// Insert terminator at the end.
rewriter.setInsertionPointToEnd(&entry_block);
rewriter.create<xla_lhlo::TerminatorOp>(loc);
rewriter.create<lmhlo::TerminatorOp>(loc);
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
@ -321,8 +321,8 @@ class HloToLhloTensorStoreOpConverter
LogicalResult matchAndRewrite(
mlir::TensorStoreOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
op, llvm::None, operands.front(), operands.back());
rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
operands.back());
return success();
}
};
@ -336,7 +336,7 @@ class HloToLhloTensorStoreOpConverter
// %arg1: memref<2x2xf32>,
// %arg2: memref<2x2xf32>,
// %arg3: memref<2x2xf32>) {
// "xla_lhlo.fusion"() ({
// "lmhlo.fusion"() ({
// %0 = tensor_load %arg1 : memref<2x2xf32>
// %1 = tensor_load %arg2 : memref<2x2xf32>
// %2 = "mhlo.add"(%0, %1) :
@ -345,7 +345,7 @@ class HloToLhloTensorStoreOpConverter
// %4 = "mhlo.multiply"(%2, %3) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// tensor_store %4, %arg3 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> ()
// "lmhlo.terminator"() : () -> ()
// }) : () -> ()
// return
// }
@ -355,13 +355,13 @@ class HloToLhloTensorStoreOpConverter
// %arg1: memref<2x2xf32>,
// %arg2: memref<2x2xf32>,
// %arg3: memref<2x2xf32>) {
// "xla_lhlo.fusion"() ( {
// "lmhlo.fusion"() ( {
// %0 = alloc() : memref<2x2xf32>
// "xla_lhlo.add"(%arg1, %arg2, %0) :
// "lmhlo.add"(%arg1, %arg2, %0) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.multiply"(%0, %arg0, %arg3) :
// "lmhlo.multiply"(%0, %arg0, %arg3) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.terminator"() : () -> ()
// "lmhlo.terminator"() : () -> ()
// }) : () -> ()
// return
// }
@ -382,13 +382,13 @@ class HloToLhloTensorStoreOpConverter
// %arg2: memref<4xf32>) {
// %0 = alloc() : memref<4xf32>
// "xla_lhlo.maximum"(%arg0, %arg1, %0) :
// "lmhlo.maximum"(%arg0, %arg1, %0) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// %1 = alloc() : memref<4xf32>
// "xla_lhlo.add"(%arg0, %0, %1) :
// "lmhlo.add"(%arg0, %0, %1) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.terminator"() : () -> ()
// "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
// "lmhlo.terminator"() : () -> ()
// }
struct HloLegalizeToLhlo
@ -406,7 +406,7 @@ struct HloLegalizeToLhlo
OwningRewritePatternList patterns;
auto& context = getContext();
ConversionTarget target(context);
target.addLegalDialect<xla_lhlo::XlaLhloDialect>();
target.addLegalDialect<lmhlo::LmhloDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalOp<ModuleOp>();
target.addIllegalOp<mlir::TensorLoadOp>();
@ -441,12 +441,12 @@ struct HloLegalizeToLhlo
&converter, &patterns);
if (results_escape_function) {
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp,
/*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
&converter, &patterns);
} else {
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp,
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
&converter, &patterns);
}

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
// Removes LHLO copy operations that copy from allocated buffers to block
@ -34,7 +34,7 @@ struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
void runOnOperation() override {
llvm::SmallVector<mlir::Operation*, 2> eraseList;
auto operation = getOperation();
operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) {
operation->walk([&](mlir::lmhlo::CopyOp copyOp) {
// If this region contains more than one block, then ignore this copy
// operation.
if (copyOp.getParentRegion()->getBlocks().size() > 1) {
@ -101,5 +101,5 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass() {
static PassRegistration<LhloCopyRemoval> copy_removal_pass(
"lhlo-copy-removal", "Removes redundant LHLO copy operations");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
using linalg::LinalgOp;
@ -147,5 +147,5 @@ static PassRegistration<LhloFuseLinalg> legalize_pass(
"lhlo-fuse-linalg",
"Greedily fuse linalg ops obtained after LHLO lowering.");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
auto result =
rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<DotOp>(
Value op_result = lmhlo::XlaOpToStdScalarOp::map<DotOp>(
op, element_type, {l, r, result}, &builder);
map_status = success(op_result != nullptr);
if (failed(map_status)) return;
@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
ValueRange induction_vars) {
auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
Value op_result = lmhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
op, element_type, {l, r}, &builder);
map_status = success(op_result != nullptr);
if (failed(map_status)) return;
@ -127,13 +127,13 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<
BinaryOpConverter<xla_lhlo::AddOp>,
BinaryOpConverter<xla_lhlo::AndOp>,
BinaryOpConverter<xla_lhlo::DivOp>,
BinaryOpConverter<xla_lhlo::MaxOp>,
BinaryOpConverter<xla_lhlo::MinOp>,
BinaryOpConverter<xla_lhlo::MulOp>,
BinaryOpConverter<xla_lhlo::SubOp>,
BinaryOpConverter<lmhlo::AddOp>,
BinaryOpConverter<lmhlo::AndOp>,
BinaryOpConverter<lmhlo::DivOp>,
BinaryOpConverter<lmhlo::MaxOp>,
BinaryOpConverter<lmhlo::MinOp>,
BinaryOpConverter<lmhlo::MulOp>,
BinaryOpConverter<lmhlo::SubOp>,
DotOpConverter>(context);
// clang-format on
}
@ -157,5 +157,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() {
static PassRegistration<LhloLegalizeToAffine> legalize_pass(
"lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -38,7 +38,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
// A simple translation of LHLO reduce operations to a corresponding gpu
@ -173,7 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>();
gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
target.addIllegalOp<ReduceOp>();
auto func = getFunction();
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
@ -192,5 +192,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() {
static PassRegistration<LhloLegalizeToGpu> legalize_pass(
"lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
struct StaticMemRefCastOpConverter
@ -132,5 +132,5 @@ void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
*converter, options);
}
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
class TestLhloToLLVMPass
@ -42,7 +42,7 @@ class TestLhloToLLVMPass
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalDialect<XlaLhloDialect>();
target.addIllegalDialect<LmhloDialect>();
if (failed(applyFullConversion(m, target, patterns))) {
signalPassFailure();
@ -55,5 +55,5 @@ class TestLhloToLLVMPass
static PassRegistration<TestLhloToLLVMPass> legalize_lhlo_pass(
"test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM.");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
// Clones and adapts the code in `lhlo_block` that works on buffers and has a
@ -154,14 +154,14 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
return b->create<scf::ParallelOp>(loc, lower, upper, step);
}
// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// Converts `lmhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// The outper `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp`
// contains the reduction operator.
//
// Example:
//
// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( {
// "lmhlo.reduce"(%buffer, %init_buf, %result) ( {
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// <LHLO ops>
// } ) {dimensions = dense<[1]> : tensor<1xi64>}
@ -187,12 +187,12 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
// } : f32
// scf.yield
// }
class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
public:
using OpConversionPattern<xla_lhlo::ReduceOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
lmhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
// TODO(b/137624192) Implement variadic reduce.
if (xla_reduce_op.out().size() != 1) return failure();
@ -226,7 +226,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// scf.yield
// }
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceOp xla_reduce_op,
lmhlo::ReduceOp xla_reduce_op,
ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_op.getLoc();
DenseSet<int> reducing_dims;
@ -314,7 +314,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// accumulator = reduction_operator(output[O], value)
// output[O] = accumulator
//
// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a
// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a
// scf::ReduceOp.
// The outper `ParallelOp` refers to the parallel loops that traverese output
// buffer. The inner `ParalleOp` refers to the reduction loops that traverse
@ -325,11 +325,11 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// func @reduce_window(%arg: memref<112x112xf32>,
// %init: memref<f32>,
// %result: memref<56x56xf32>) {
// "xla_lhlo.reduce_window"(%arg, %init, %result) ( {
// "lmhlo.reduce_window"(%arg, %init, %result) ( {
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// "xla_lhlo.maximum"(%lhs, %rhs, %res)
// "lmhlo.maximum"(%lhs, %rhs, %res)
// : (memref<f32>, memref<f32>, memref<f32>) -> ()
// "xla_lhlo.terminator"() : () -> ()
// "lmhlo.terminator"() : () -> ()
// }) {
// padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
// window_dimensions = dense<[3, 3]> : tensor<2xi64>,
@ -359,12 +359,12 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// return
// }
class ReduceWindowOpConverter
: public OpConversionPattern<xla_lhlo::ReduceWindowOp> {
: public OpConversionPattern<lmhlo::ReduceWindowOp> {
public:
using OpConversionPattern<xla_lhlo::ReduceWindowOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
lmhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
scf::ParallelOp output_loop, window_loop;
std::tie(output_loop, window_loop) =
@ -383,7 +383,7 @@ class ReduceWindowOpConverter
private:
std::pair<scf::ParallelOp, scf::ParallelOp>
CreateParallelLoopsToTraverseOutputAndWindow(
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
lmhlo::ReduceWindowOp xla_reduce_window_op,
ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_window_op.getLoc();
Value init_value =
@ -415,9 +415,8 @@ class ReduceWindowOpConverter
}
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
scf::ParallelOp output_loop, scf::ParallelOp window_loop,
ConversionPatternRewriter* rewriter) const {
lmhlo::ReduceWindowOp xla_reduce_window_op, scf::ParallelOp output_loop,
scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
rewriter->setInsertionPointToStart(window_loop.getBody());
auto loc = xla_reduce_window_op.getLoc();
@ -481,12 +480,12 @@ class ReduceWindowOpConverter
// initialized_flag = true
// output(selected_index) = scatter(output(selected_index), source(S))
class SelectAndScatterOpConverter
: public OpConversionPattern<xla_lhlo::SelectAndScatterOp> {
: public OpConversionPattern<lmhlo::SelectAndScatterOp> {
public:
using OpConversionPattern<xla_lhlo::SelectAndScatterOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::SelectAndScatterOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
auto loc = s_and_s_op.getLoc();
InitializeOutput(s_and_s_op, &rewriter);
@ -515,7 +514,7 @@ class SelectAndScatterOpConverter
}
private:
void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op,
void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
@ -533,7 +532,7 @@ class SelectAndScatterOpConverter
SmallVector<Value, 2> window_ivs;
scf::ForOp inner_loop;
};
WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op,
WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op,
scf::ParallelOp loop_over_src,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
@ -598,7 +597,7 @@ class SelectAndScatterOpConverter
SmallVector<Value, 4> ivs_val_flag_;
};
SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op,
SmallVector<Value, 2> SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op,
scf::ParallelOp loop_over_src,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
@ -636,9 +635,10 @@ class SelectAndScatterOpConverter
return window_loops.selected_ivs;
}
SmallVector<Value, 4> SelectOrInitialize(
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> operand_ivs,
IterArgs* ivs_val_flag, OpBuilder* b) const {
SmallVector<Value, 4> SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op,
ArrayRef<Value> operand_ivs,
IterArgs* ivs_val_flag,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
Value true_i1 = b->create<mlir::ConstantOp>(
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
@ -707,9 +707,9 @@ struct LhloLegalizeToParallelLoops
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
scf::SCFDialect, XlaLhloDialect>();
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
xla_lhlo::SelectAndScatterOp>();
scf::SCFDialect, LmhloDialect>();
target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
lmhlo::SelectAndScatterOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
@ -727,5 +727,5 @@ static PassRegistration<LhloLegalizeToParallelLoops> legalize_lhlo_pass(
"lhlo-legalize-to-parallel-loops",
"Legalize from LHLO dialect to parallel loops.");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -131,9 +131,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
loc, opResultTypes, args, args_count, results_count, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace.
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
// That method needs to be moved out of there.
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
Value opResult = lmhlo::XlaOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes,
llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
@ -162,8 +162,8 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
// Create two loads from the input.
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
Value opResult = lmhlo::XlaOpToStdScalarOp::map<LhloOp>(
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
@ -173,21 +173,21 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
};
//===----------------------------------------------------------------------===//
// xla_lhlo.convolution conversion pattern.
// lmhlo.convolution conversion pattern.
//===----------------------------------------------------------------------===//
/// Converts xla_lhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
/// Converts lmhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
public:
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;
// This code has been adapted from IREE's
// (https://github.com/google/iree/) mhlo -> linalg conversion.
LogicalResult matchAndRewrite(
xla_lhlo::ConvOp op, ArrayRef<Value> args,
lmhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information.
if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
if (const lmhlo::ConvDimensionNumbers& dimensionNumbers =
op.dimension_numbers()) {
const int inputSpatialRank =
llvm::size(dimensionNumbers.input_spatial_dimensions());
@ -388,14 +388,14 @@ class HloBroadcastInDimConverter
};
class LhloBroadcastInDimConverter
: public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
: public OpConversionPattern<lmhlo::BroadcastInDimOp> {
public:
using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::BroadcastInDimOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
auto result_shape = result_type.getShape();
@ -444,9 +444,9 @@ class LhloBroadcastInDimConverter
// Inserts 'linalg.reshape' if there is a size-1 dim expansion.
std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const {
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
Value operand = operand_adaptor.operand();
auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
@ -512,7 +512,7 @@ class LhloBroadcastInDimConverter
return std::make_pair(operand, broadcast_dims);
}
SmallVector<AffineMap, 2> getIndexingMaps(xla_lhlo::BroadcastInDimOp op,
SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
ArrayRef<int64_t> broadcastDims,
ArrayRef<int64_t> resultShape,
MemRefType operandType,
@ -639,12 +639,12 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
}
};
class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
public:
using OpConversionPattern<xla_lhlo::IotaOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::IotaOp iotaOp, ArrayRef<Value> args,
lmhlo::IotaOp iotaOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto resultMemrefType =
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
@ -680,12 +680,12 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
}
};
class ConstConverter : public OpConversionPattern<xla_lhlo::ConstOp> {
class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
public:
using OpConversionPattern<xla_lhlo::ConstOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::ConstOp constOp, ArrayRef<Value> args,
lmhlo::ConstOp constOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = constOp.getLoc();
auto valueAttr = constOp.value().cast<DenseElementsAttr>();
@ -726,12 +726,12 @@ class ReverseConverter
}
};
class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
public:
using OpConversionPattern<xla_lhlo::SliceOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::SliceOp sliceOp, ArrayRef<Value> args,
lmhlo::SliceOp sliceOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = sliceOp.getLoc();
auto argType =
@ -763,50 +763,50 @@ class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
ConstConverter,
ConvToLinalgConverter,
IotaConverter,
LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
PointwiseToLinalgConverter<xla_lhlo::AddOp>,
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
PointwiseToLinalgConverter<xla_lhlo::CeilOp>,
PointwiseToLinalgConverter<xla_lhlo::CompareOp>,
PointwiseToLinalgConverter<xla_lhlo::ComplexOp>,
PointwiseToLinalgConverter<xla_lhlo::ConvertOp>,
PointwiseToLinalgConverter<lmhlo::AbsOp>,
PointwiseToLinalgConverter<lmhlo::AddOp>,
PointwiseToLinalgConverter<lmhlo::AndOp>,
PointwiseToLinalgConverter<lmhlo::CeilOp>,
PointwiseToLinalgConverter<lmhlo::CompareOp>,
PointwiseToLinalgConverter<lmhlo::ComplexOp>,
PointwiseToLinalgConverter<lmhlo::ConvertOp>,
// TODO(ataei): Remove this pattern, CopyOp is folded away.
PointwiseToLinalgConverter<xla_lhlo::CopyOp>,
PointwiseToLinalgConverter<xla_lhlo::CosOp>,
PointwiseToLinalgConverter<xla_lhlo::DivOp>,
PointwiseToLinalgConverter<xla_lhlo::ExpOp>,
PointwiseToLinalgConverter<xla_lhlo::ImagOp>,
PointwiseToLinalgConverter<xla_lhlo::LogOp>,
PointwiseToLinalgConverter<xla_lhlo::MaxOp>,
PointwiseToLinalgConverter<xla_lhlo::MinOp>,
PointwiseToLinalgConverter<xla_lhlo::MulOp>,
PointwiseToLinalgConverter<xla_lhlo::NegOp>,
PointwiseToLinalgConverter<xla_lhlo::RealOp>,
PointwiseToLinalgConverter<xla_lhlo::RemOp>,
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SelectOp>,
PointwiseToLinalgConverter<xla_lhlo::SignOp>,
PointwiseToLinalgConverter<xla_lhlo::SinOp>,
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
ReshapeOpConverter<xla_lhlo::ReshapeOp>,
ReverseConverter<xla_lhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
PointwiseToLinalgConverter<lmhlo::CopyOp>,
PointwiseToLinalgConverter<lmhlo::CosOp>,
PointwiseToLinalgConverter<lmhlo::DivOp>,
PointwiseToLinalgConverter<lmhlo::ExpOp>,
PointwiseToLinalgConverter<lmhlo::ImagOp>,
PointwiseToLinalgConverter<lmhlo::LogOp>,
PointwiseToLinalgConverter<lmhlo::MaxOp>,
PointwiseToLinalgConverter<lmhlo::MinOp>,
PointwiseToLinalgConverter<lmhlo::MulOp>,
PointwiseToLinalgConverter<lmhlo::NegOp>,
PointwiseToLinalgConverter<lmhlo::RealOp>,
PointwiseToLinalgConverter<lmhlo::RemOp>,
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
PointwiseToLinalgConverter<lmhlo::SelectOp>,
PointwiseToLinalgConverter<lmhlo::SignOp>,
PointwiseToLinalgConverter<lmhlo::SinOp>,
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
PointwiseToLinalgConverter<lmhlo::SubOp>,
PointwiseToLinalgConverter<lmhlo::TanhOp>,
ReshapeOpConverter<lmhlo::ReshapeOp>,
ReverseConverter<lmhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
SliceConverter
>(context);
// clang-format on
}
// Converts LHLO ops to Linalg generic.
// Sample result for xla_lhlo::AddOp.
// Sample result for lmhlo::AddOp.
//
// "xla_lhlo.add"(%arg1, %arg2, %out) :
// "lmhlo.add"(%arg1, %arg2, %out) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//
// will be converted to
@ -854,14 +854,14 @@ struct HloLegalizeToLinalg
} // namespace
namespace xla_lhlo {
namespace lmhlo {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
return absl::make_unique<LhloLegalizeToLinalg>();
}
static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
} // namespace xla_lhlo
} // namespace lmhlo
namespace mhlo {

View File

@ -7,7 +7,7 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_result = "mhlo.exponential"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -18,10 +18,10 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
return %arg0 : tensor<4xf32>
}
// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// PRE-NEXT: return
// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// ESC-NOT: "xla_lhlo.copy"
// ESC-NOT: "lmhlo.copy"
// ESC-NEXT: return %[[ARG0]]
// -----
@ -38,20 +38,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
//  BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
//  BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
// PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
// PRE-NEXT: return
// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
@ -67,14 +67,14 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
// BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
tensor_store %tensor_result, %result : memref<2x2xf32>
// BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: return
@ -88,7 +88,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.copy"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -100,7 +100,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.exponential"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -112,7 +112,7 @@ func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.log"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.log"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -127,7 +127,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
// BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -141,7 +141,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
%tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
{comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
// BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
tensor_store %tensor_result, %result : memref<2x2xi1>
return
}
@ -154,7 +154,7 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
%tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
{broadcast_dimensions = dense<1> : tensor<1xi64>}
: (tensor<5xf32>) -> tensor<10x5xf32>
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<10x5xf32>
return
}
@ -205,12 +205,12 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast
// BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
// BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
@ -229,7 +229,7 @@ func @complex(%real: memref<2x2xf32>,
%tensor_imag = tensor_load %imag : memref<2x2xf32>
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
return
}
@ -241,7 +241,7 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "mhlo.real"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -253,7 +253,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "mhlo.imag"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -264,7 +264,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
func @iota(%result: memref<10xi32>) {
%tensor_result = "mhlo.iota"()
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
// BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
tensor_store %tensor_result, %result : memref<10xi32>
return
}
@ -276,7 +276,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.abs"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -288,7 +288,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.ceil"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -300,7 +300,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.convert"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
// BOTH-NOT: tensor_store
tensor_store %tensor_result, %result : memref<2x2xf32>
return
@ -313,7 +313,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.cosine"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -325,7 +325,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.negate"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -337,7 +337,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.rsqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -349,7 +349,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.sign"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -361,7 +361,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.sqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -373,7 +373,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.tanh"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -386,7 +386,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
// BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -412,7 +412,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
@ -437,7 +437,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
@ -448,10 +448,10 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "mhlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
// ESC: return %[[ALLOC]]
return %dot : tensor<1024x1024xf32>
}
@ -462,7 +462,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index
// BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// BOTH-SAME: padding = dense<[
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// BOTH-SAME: rhs_dilation = dense<[1, 2]>

View File

@ -3,10 +3,10 @@
// CHECK-LABEL: func @remove_simple
func @remove_simple(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
@ -14,9 +14,9 @@ func @remove_simple(%arg0: memref<2x2xf32>) {
// CHECK-LABEL: func @remove_without_dealloc
func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
@ -24,22 +24,22 @@ func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
// CHECK-LABEL: func @replace_dependency
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @keep_copies
func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
// CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
// CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
@ -50,14 +50,14 @@ func @must_not_be_removed(%arg0: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
// CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32>
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
@ -67,13 +67,13 @@ func @must_be_removed_first(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
@ -83,11 +83,11 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}

View File

@ -10,18 +10,18 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
%src: memref<56x56xf32>,
%init: memref<f32>,
%result: memref<112x112xf32>) {
"xla_lhlo.select_and_scatter"(%arg, %src, %init, %result) ( {
"lmhlo.select_and_scatter"(%arg, %src, %init, %result) ( {
// select
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %pred: memref<i1>):
"xla_lhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} :
"lmhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} :
(memref<f32>, memref<f32>, memref<i1>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}, {
// scatter
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %out) :
"lmhlo.add"(%lhs, %rhs, %out) :
(memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}) {
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
window_dimensions = dense<[3, 3]> : tensor<2xi64>,
@ -29,7 +29,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
} : (memref<112x112xf32>,
memref<56x56xf32>,
memref<f32>, memref<112x112xf32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// CHECK-LABEL: func @select_and_scatter(
// CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>,
@ -121,7 +121,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
// Compute PRED.
// CHECK: "xla_lhlo.compare"(
// CHECK: "lmhlo.compare"(
// CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
// CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1>
@ -182,7 +182,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
// Compute scatter value.
// CHECK: "xla_lhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>

View File

@ -14,7 +14,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
// CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
// CHECK: return
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
(memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> ()
return
}
@ -24,7 +24,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: addf %{{.*}}, %{{.*}} : f32
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
"lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -32,7 +32,7 @@ func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: addi %{{.*}}, %{{.*}} : i32
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
"lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -42,7 +42,7 @@ func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: and %{{.*}}, %{{.*}} : i32
"xla_lhlo.and"(%lhs, %rhs, %result) {name = "and.1"}
"lmhlo.and"(%lhs, %rhs, %result) {name = "and.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -52,7 +52,7 @@ func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: divf %{{.*}}, %{{.*}} : f32
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
"lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -60,7 +60,7 @@ func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: divi_signed %{{.*}}, %{{.*}} : i32
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
"lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -71,7 +71,7 @@ func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], %[[TWO:.*]] : f32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
"lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -81,7 +81,7 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
"lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -92,7 +92,7 @@ func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpf "olt", %[[ONE:.*]], %[[TWO:.*]] : f32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -102,7 +102,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -112,7 +112,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: mulf %{{.*}}, %{{.*}} : f32
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
"lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -121,7 +121,7 @@ func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: muli %{{.*}}, %{{.*}} : i32
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
"lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -131,7 +131,7 @@ func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: subf %{{.*}}, %{{.*}} : f32
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
"lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -139,7 +139,7 @@ func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: subi %{{.*}}, %{{.*}} : i32
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
"lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -158,7 +158,7 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs:
// CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
// CHECK: return
"xla_lhlo.dot"(%lhs, %rhs, %result) :
"lmhlo.dot"(%lhs, %rhs, %result) :
(memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> ()
return
}
@ -175,7 +175,7 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
// CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
// CHECK: return
"xla_lhlo.dot"(%lhs, %rhs, %result) :
"lmhlo.dot"(%lhs, %rhs, %result) :
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
return
}

View File

@ -3,11 +3,11 @@
func @reduce(%arg: memref<100x10xf32>,
%init: memref<f32>,
%result: memref<100xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
return
@ -25,7 +25,7 @@ func @reduce(%arg: memref<100x10xf32>,
// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref<f32, #map0>
// CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref<f32, #map0>
// CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
// CHECK: }
// CHECK: gpu.terminator
// CHECK: }

View File

@ -4,7 +4,7 @@
// CHECK-LABEL: func @element_wise
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
"lmhlo.add"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -19,7 +19,7 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>,
%rhs: memref<?x?xf32>,
%result: memref<?x?xf32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
"lmhlo.add"(%lhs, %rhs, %result)
: (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
@ -33,7 +33,7 @@ func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>,
// CHECK-LABEL: func @element_wise_scalar
func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
%result: memref<f32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
"lmhlo.add"(%lhs, %rhs, %result)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
return
}
@ -48,7 +48,7 @@ func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
// CHECK-LABEL: func @minf
func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.minimum"(%lhs, %rhs, %result)
"lmhlo.minimum"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -63,7 +63,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @maxi
func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.maximum"(%lhs, %rhs, %result)
"lmhlo.maximum"(%lhs, %rhs, %result)
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -78,7 +78,7 @@ func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// CHECK-LABEL: func @and
func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.and"(%lhs, %rhs, %result)
"lmhlo.and"(%lhs, %rhs, %result)
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -91,7 +91,7 @@ func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.exponential"(%input, %result)
"lmhlo.exponential"(%input, %result)
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -104,7 +104,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @log
func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -116,7 +116,7 @@ func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @copy
func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
"xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
"lmhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -128,7 +128,7 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xi1>) {
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> ()
return
}
@ -142,7 +142,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi1>) {
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> ()
return
}
@ -156,7 +156,7 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.select"(%pred, %lhs, %rhs, %result)
"lmhlo.select"(%pred, %lhs, %rhs, %result)
: (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -170,7 +170,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota
func @iota(%out: memref<7x10xf32>) {
"xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
"lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
return
}
// CHECK: linalg.indexed_generic
@ -186,7 +186,7 @@ func @iota(%out: memref<7x10xf32>) {
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
"xla_lhlo.broadcast"(%operand, %result) {
"lmhlo.broadcast"(%operand, %result) {
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
} : (memref<f32>, memref<4x2x1xf32>) -> ()
return
@ -203,7 +203,7 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
// CHECK-LABEL: func @broadcast
func @broadcast(%operand: memref<4x?x16xf32>,
%result: memref<4x2x1x4x?x16xf32>) {
"xla_lhlo.broadcast"(%operand, %result) {
"lmhlo.broadcast"(%operand, %result) {
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
} : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> ()
return
@ -220,7 +220,7 @@ func @broadcast(%operand: memref<4x?x16xf32>,
// CHECK-LABEL: func @dynamic_broadcast_in_dim
func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
%result: memref<?x?x?x?x?xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>
} : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
return
@ -237,7 +237,7 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_no_expansion
func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
%result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<5xf32>, memref<5x10xf32>) -> ()
return
@ -255,7 +255,7 @@ func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_expansion
func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
%result: memref<5x10x100xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[2, 0]> : tensor<2xi64>
} : (memref<1x5xf32>, memref<5x10x100xf32>) -> ()
return
@ -274,7 +274,7 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_scalar
func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
%result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[]> : tensor<0xi64>
} : (memref<f32>, memref<5x10xf32>) -> ()
return
@ -291,7 +291,7 @@ func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one
func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
%result: memref<1x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<1xf32>, memref<1x5xf32>) -> ()
return
@ -307,7 +307,7 @@ func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many
func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
%result: memref<5x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[1]> : tensor<1xi64>
} : (memref<1xf32>, memref<5x5xf32>) -> ()
return
@ -323,7 +323,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
// CHECK-LABEL: func @constant
func @constant(%value: memref<i32>) {
"xla_lhlo.constant"(%value) {
"lmhlo.constant"(%value) {
value = dense<10> : tensor<i32>
} : (memref<i32>) -> ()
return
@ -335,7 +335,7 @@ func @constant(%value: memref<i32>) {
// CHECK-LABEL: func @absf
func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -348,7 +348,7 @@ func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @absi
func @absi(%input: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -364,7 +364,7 @@ func @absi(%input: memref<2x2xi32>,
// CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -376,7 +376,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -389,7 +389,7 @@ func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: memref<2x2xi16>,
%result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
@ -401,7 +401,7 @@ func @convert_i16_to_i32(%input: memref<2x2xi16>,
// CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> ()
return
}
// CHECK: linalg.generic
@ -413,7 +413,7 @@ func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) {
// CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> ()
return
}
// CHECK: linalg.generic
@ -425,7 +425,7 @@ func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) {
// CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -437,7 +437,7 @@ func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_i32_to_i32
func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
@ -448,7 +448,7 @@ func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @convert_f32_to_f32
func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -459,7 +459,7 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result)
"lmhlo.convert"(%input, %result)
: (memref<2x2xf32>, memref<2x2xi32>) -> ()
return
}
@ -472,7 +472,7 @@ func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -485,7 +485,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result)
"lmhlo.sine"(%input, %result)
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -498,7 +498,7 @@ func @sin(%input: memref<2x2xf32>,
// CHECK-LABEL: func @negf
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -510,7 +510,7 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @negi
func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
"xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
@ -524,7 +524,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @rem
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.remainder"(%lhs, %rhs, %result)
"lmhlo.remainder"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -537,7 +537,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @rsqrt
func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -549,7 +549,7 @@ func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sign
func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -562,7 +562,7 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sqrt
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -574,7 +574,7 @@ func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @tanh
func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -588,7 +588,7 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @complex(%real: memref<2x2xf32>,
%imag: memref<2x2xf32>,
%cplx: memref<2x2xcomplex<f32>>) {
"xla_lhlo.complex"(%real, %imag, %cplx)
"lmhlo.complex"(%real, %imag, %cplx)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex<f32>>) -> ()
return
}
@ -602,7 +602,7 @@ func @complex(%real: memref<2x2xf32>,
// CHECK-LABEL: func @real
func @real(%cplx: memref<2x2xcomplex<f32>>,
%real: memref<2x2xf32>) {
"xla_lhlo.real"(%cplx, %real)
"lmhlo.real"(%cplx, %real)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return
}
@ -616,7 +616,7 @@ func @real(%cplx: memref<2x2xcomplex<f32>>,
// CHECK-LABEL: func @imag
func @imag(%cplx: memref<2x2xcomplex<f32>>,
%imag: memref<2x2xf32>) {
"xla_lhlo.imag"(%cplx, %imag)
"lmhlo.imag"(%cplx, %imag)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return
}
@ -629,7 +629,7 @@ func @imag(%cplx: memref<2x2xcomplex<f32>>,
// CHECK: func @slice(%[[IN:.*]]: memref<?x?xf32>, %[[OUT:.*]]: memref<?x?xf32>)
func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
"xla_lhlo.slice"(%operand, %result) {
"lmhlo.slice"(%operand, %result) {
start_indices = dense<[0,1]> : tensor<2xi64>,
limit_indices = dense<[2,3]> : tensor<2xi64>,
strides = dense<[1,1]> : tensor<2xi64>
@ -653,7 +653,7 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
"lmhlo.reshape"(%arg0, %arg1)
: (memref<12x1x42xi32>, memref<12x42xi32>) -> ()
return
}
@ -666,7 +666,7 @@ func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
"lmhlo.reshape"(%arg0, %arg1)
: (memref<12x42x1x1xi32>, memref<12x42xi32>) -> ()
return
}
@ -679,7 +679,7 @@ func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
"lmhlo.reshape"(%arg0, %arg1)
: (memref<12x42xi32>, memref<12x1x42x1xi32>) -> ()
return
}
@ -692,7 +692,7 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse
func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
"xla_lhlo.reverse"(%arg0, %arg1) {
"lmhlo.reverse"(%arg0, %arg1) {
dimensions = dense<1> : tensor<1xi64>
} : (memref<2x3xf32>, memref<2x3xf32>) -> ()
return
@ -710,15 +710,15 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: strides = [2, 1]}
// With all atributes explicitly specified.
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
// Dilation left unspecified, sets default dilation since linalg expects it.
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 1]
// Padding is not set if it's zero.
// CHECK-NOT: padding
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
"lmhlo.terminator"() : () -> ()
}

View File

@ -2,7 +2,7 @@
// CHECK-LABEL: func @static_memref_cast
func @static_memref_cast(%buf : memref<10x1x5xf32>) {
%0 = xla_lhlo.static_memref_cast %buf
%0 = lmhlo.static_memref_cast %buf
: memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]>
return
}
@ -38,7 +38,7 @@ func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
%size_Y = constant 50 : index
%stride_X = constant 1 : index
%stride_Y = constant 0 : index
%0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
%0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}

View File

@ -3,11 +3,11 @@
func @reduce(%arg: memref<100x10x5xf32>,
%init: memref<f32>,
%result: memref<100x5xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
return
@ -35,7 +35,7 @@ func @reduce(%arg: memref<100x10x5xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
@ -49,11 +49,11 @@ func @reduce(%arg: memref<100x10x5xf32>,
func @reduce_no_outer_loop(%arg: memref<100xf32>,
%init: memref<f32>,
%result: memref<1xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>}
: (memref<100xf32>, memref<f32>, memref<1xf32>) -> ()
return
@ -76,7 +76,7 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]]
// CHECK: }
@ -88,11 +88,11 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
func @dynamic_reduce(%arg: memref<?x?x?xf32>,
%init: memref<f32>,
%result: memref<?x?xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<?x?x?xf32>, memref<f32>, memref<?x?xf32>) -> ()
return
@ -121,7 +121,7 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
@ -135,11 +135,11 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
func @reduce_window(%arg: memref<112x112xf32>,
%init: memref<f32>,
%result: memref<56x56xf32>) {
"xla_lhlo.reduce_window"(%arg, %init, %result) ( {
"lmhlo.reduce_window"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.maximum"(%lhs, %rhs, %res)
"lmhlo.maximum"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}) {
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
window_dimensions = dense<[3, 3]> : tensor<2xi64>,
@ -189,7 +189,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }

View File

@ -4,7 +4,7 @@
// CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -12,7 +12,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -20,7 +20,7 @@ func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -28,7 +28,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
@ -36,7 +36,7 @@ func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -44,7 +44,7 @@ func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -52,7 +52,7 @@ func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
"lmhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
@ -60,7 +60,7 @@ func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -68,7 +68,7 @@ func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @add_memrefs
func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -76,7 +76,7 @@ func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1
// CHECK-LABEL: func @abs_memref
func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -84,7 +84,7 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @convert_memref
func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
"lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
return
}
@ -92,7 +92,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
"lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
return
}
@ -100,7 +100,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -108,7 +108,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
"lmhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
@ -116,7 +116,7 @@ func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -124,7 +124,7 @@ func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @log_memref
func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -132,7 +132,7 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @log_memref
func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
@ -140,7 +140,7 @@ func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) ->
func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -148,7 +148,7 @@ func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @neg_memref
func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -156,7 +156,7 @@ func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -164,7 +164,7 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
@ -172,7 +172,7 @@ func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>)
func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -180,7 +180,7 @@ func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -188,7 +188,7 @@ func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
@ -196,7 +196,7 @@ func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -
func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -204,7 +204,7 @@ func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @sign_memref
func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -212,7 +212,7 @@ func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @tanh_memref
func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -220,7 +220,7 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @tanh_memref
func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
@ -228,15 +228,15 @@ func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -
func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
"xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
// expected-error@+1{{'lmhlo.tanh' op requires all operands to have the same type}}
"lmhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
return
}
@ -244,7 +244,7 @@ func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
// CHECK-LABEL: func @add_memref
func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -252,7 +252,7 @@ func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @div_memref
func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -260,7 +260,7 @@ func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @max_memref
func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -268,7 +268,7 @@ func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @min_memref
func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -276,7 +276,7 @@ func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @mul_memref
func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -284,7 +284,7 @@ func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @sub_memref
func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -292,7 +292,7 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -300,7 +300,7 @@ func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32
// CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
@ -308,7 +308,7 @@ func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>)
func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -316,7 +316,7 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -324,7 +324,7 @@ func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>
// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
@ -332,7 +332,7 @@ func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -
func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -340,7 +340,7 @@ func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>
// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -348,7 +348,7 @@ func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32
// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
@ -356,7 +356,7 @@ func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>)
func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -364,7 +364,7 @@ func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @broadcast_in_dim_memref
func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
"lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
return
}
@ -372,7 +372,7 @@ func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -
// CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref
func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
"lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
return
}
@ -381,10 +381,10 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi
// CHECK-LABEL: func @reduce_memref
func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {
"xla_lhlo.reduce"(%input, %init, %out) ( {
"lmhlo.reduce"(%input, %init, %out) ( {
^bb0(%arg1: memref<f32>, %arg2: memref<f32>, %result: memref<f32>):
"xla_lhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref<f32>, memref<1xf32>) -> ()
return
}
@ -393,14 +393,14 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf
// CHECK-LABEL: func @fusion_memref
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.fusion"() ( {
"lmhlo.fusion"() ( {
%0 = tensor_load %input1 : memref<10xf32>
%1 = tensor_load %input2 : memref<10xf32>
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%3 = tensor_load %input3 : memref<10xf32>
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
tensor_store %4, %out : memref<10xf32>
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) : () -> ()
return
}
@ -409,18 +409,18 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
// CHECK-LABEL: func @case_memref
func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () {
"xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
"lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
^bb0(%arg0: memref<f32>):
"xla_lhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}, {
^bb0(%arg0: memref<f32>):
"xla_lhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}, {
^bb0(%arg0: memref<f32>):
"xla_lhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}
) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>}
: (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
@ -430,7 +430,7 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr
// -----
func @static_memref_cast(%in: memref<10x1xf32>) {
%out = xla_lhlo.static_memref_cast %in
%out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]>
return
}
@ -440,7 +440,7 @@ func @static_memref_cast(%in: memref<10x1xf32>) {
func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
// expected-error @+1 {{operand must have static shape}}
%out = xla_lhlo.static_memref_cast %in
%out = lmhlo.static_memref_cast %in
: memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]>
return
}
@ -449,7 +449,7 @@ func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
// expected-error @+1 {{result must have static shape}}
%out = xla_lhlo.static_memref_cast %in
%out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]>
return
}
@ -459,7 +459,7 @@ func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
func @dynamic_memref_cast(%in: memref<?xf32>) {
%size = constant 10 : index
%step = constant 1 : index
%out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>
return
}
@ -471,7 +471,7 @@ func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
// expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}}
%size = constant 10 : index
%step = constant 1 : index
%out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}
@ -483,19 +483,19 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
// CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>,
// CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32>
// CHECK-NEXT: [[DYN_VEC:%.*]] = xla_lhlo.reshape_memref_cast [[UNRANKED]]
// CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]]
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
%dyn_vec = xla_lhlo.reshape_memref_cast %unranked(%shape1)
%dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1)
: (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// CHECK-NEXT: [[DYN_MAT:%.*]] = xla_lhlo.reshape_memref_cast [[DYN_VEC]]
// CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]]
// CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
%dyn_mat = xla_lhlo.reshape_memref_cast %dyn_vec(%shape2)
%dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2)
: (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
// CHECK-NEXT: {{%.*}} = xla_lhlo.reshape_memref_cast [[DYN_MAT]]
// CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]]
// CHECK-SAME: : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
%new_unranked = xla_lhlo.reshape_memref_cast %dyn_mat(%shape3)
%new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3)
: (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
return
}
@ -505,7 +505,7 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
func @reshape_memref_cast_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}}
xla_lhlo.reshape_memref_cast %buf(%shape)
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
}
@ -514,7 +514,7 @@ func @reshape_memref_cast_element_type_mismatch(
func @reshape_memref_cast_dst_ranked_shape_unranked(
%buf: memref<*xf32>, %shape: memref<?xi32>) {
// expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}}
xla_lhlo.reshape_memref_cast %buf(%shape)
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
return
}
@ -524,7 +524,7 @@ func @reshape_memref_cast_dst_ranked_shape_unranked(
func @reshape_memref_cast_dst_shape_rank_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{length of shape operand differs from the result's memref rank}}
xla_lhlo.reshape_memref_cast %buf(%shape)
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
return
}
@ -535,7 +535,7 @@ func @reshape_memref_cast_affine_map_is_not_identity(
%buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
%shape: memref<1xi32>) {
// expected-error @+1 {{operand memref type should have identity affine map}}
xla_lhlo.reshape_memref_cast %buf(%shape)
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
-> memref<8xf32>
return
@ -545,7 +545,7 @@ func @reshape_memref_cast_affine_map_is_not_identity(
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -553,7 +553,7 @@ func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
@ -561,7 +561,7 @@ func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>
func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -569,7 +569,7 @@ func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref
// CHECK-LABEL: func @bitcast_convert_memrefs
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
"lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
return
}
@ -577,7 +577,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) ->
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
"lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
return
}
@ -585,7 +585,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) ->
// CHECK-LABEL: func @clz_memrefs
func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -593,7 +593,7 @@ func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -601,7 +601,7 @@ func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
"lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
@ -609,7 +609,7 @@ func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f3
// CHECK-LABEL: func @floor_memrefs
func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -617,7 +617,7 @@ func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -625,7 +625,7 @@ func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @imag_memrefs
func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
"lmhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
return
}
@ -633,7 +633,7 @@ func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> ()
func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error@+1{{must be memref of complex-type values}}
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -641,7 +641,7 @@ func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @real_memrefs
func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
"lmhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
return
}
@ -649,7 +649,7 @@ func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> ()
func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error@+1{{must be memref of complex-type values}}
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -657,7 +657,7 @@ func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @is_finite_memrefs
func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
"xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
"lmhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
return
}
@ -665,7 +665,7 @@ func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -673,7 +673,7 @@ func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
"lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
@ -681,7 +681,7 @@ func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f3
func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -689,7 +689,7 @@ func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -697,7 +697,7 @@ func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
"lmhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
return
}
@ -705,7 +705,7 @@ func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -713,7 +713,7 @@ func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @popcnt_memrefs
func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -721,7 +721,7 @@ func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -729,7 +729,7 @@ func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @reduce_precision_memrefs
func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -737,7 +737,7 @@ func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) ->
// CHECK-LABEL: func @round_memrefs
func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -745,7 +745,7 @@ func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -753,7 +753,7 @@ func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @shift_left_memrefs
func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -761,7 +761,7 @@ func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: m
func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -769,7 +769,7 @@ func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: m
// CHECK-LABEL: func @shift_right_arithmetic_memrefs
func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -777,7 +777,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>,
func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -785,7 +785,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>,
// CHECK-LABEL: func @shift_right_logical_memrefs
func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -793,7 +793,7 @@ func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %a
func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -801,14 +801,14 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
// CHECK-LABEL: func @all_reduce_memrefs
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
"lmhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
"lmhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
@ -826,11 +826,11 @@ func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> ()
// CHECK-LABEL: func @collective_permute_memrefs
func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
"lmhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
"lmhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
channel_id = { handle = 5 : i64, type = 2 : i64 }
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
@ -841,7 +841,7 @@ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128
// CHECK-LABEL: func @fft_memrefs
func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
"xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
"lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
return
}
@ -852,7 +852,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
%grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>,
%grad_offset: memref<8xf32>) -> () {
"xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
"lmhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
return
@ -863,7 +863,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>,
// CHECK-LABEL: func @batch_norm_inference_memrefs
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
"xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
"lmhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
return
}
@ -874,7 +874,7 @@ func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>,
%batch_var: memref<8xf32>) -> () {
"xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
"lmhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
return
}
@ -883,8 +883,8 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
// CHECK-LABEL: func @cholesky_memrefs
func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () {
"xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
"xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
"lmhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
"lmhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
return
}
@ -892,7 +892,7 @@ func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291x
// CHECK-LABEL: func @infeed_memrefs
func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
"xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
"lmhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
return
}
@ -900,7 +900,7 @@ func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
// CHECK-LABEL: func @outfeed_memrefs
func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
"xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
"lmhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
return
}
@ -908,7 +908,7 @@ func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
// CHECK-LABEL: func @replica_id_memrefs
func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
"xla_lhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
"lmhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
return
}
@ -916,7 +916,7 @@ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
// CHECK-LABEL: func @triangular_solve_memrefs
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
"xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
return
}
@ -925,9 +925,9 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %
// CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
"xla_lhlo.while"(%arg0, %arg_out) (
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "xla_lhlo.terminator"() : () -> () }
"lmhlo.while"(%arg0, %arg_out) (
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "lmhlo.terminator"() : () -> () }
) : (memref<i64>, memref<i64>) -> ()
return
}
@ -936,9 +936,9 @@ func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
// CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>) -> () {
"xla_lhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) (
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "xla_lhlo.terminator"() : () -> () }
"lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) (
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () }
) : (memref<i64>, memref<5xf32>, memref<i64>, memref<5xf32>) -> ()
return
}
@ -947,7 +947,7 @@ func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<
// CHECK-LABEL: func @bitcast_memrefs
func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
"xla_lhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
"lmhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
return
}
@ -956,7 +956,7 @@ func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
// CHECK-LABEL: func @scatter_memrefs
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
"lmhlo.scatter" (%input, %indices, %updates, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = mhlo.add %lhs, %rhs : tensor<f32>
"mhlo.return"(%add) : (tensor<f32>) -> ()
@ -977,7 +977,7 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
// CHECK-LABEL: func @map_memrefs
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
"lmhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
@ -989,7 +989,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
"lmhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
@ -1001,7 +1001,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
// CHECK-LABEL: func @rng_get_and_update_state_memrefs
func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
"xla_lhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
"lmhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
return
}
@ -1010,7 +1010,7 @@ func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
@ -1023,7 +1023,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
@ -1036,7 +1036,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()