Rename xla_lhlo dialect into lmhlo

Following on the plan of isolating the compiler/mlir/hlo directory.
Another xla_lhlo dialect will be created under compiler/mlir/xla/ later.

PiperOrigin-RevId: 320210326
This commit is contained in:
Mehdi Amini 2020-07-08 17:05:32 +00:00 committed by Mehdi Amini
parent b076e018a8
commit 7c4a5d62b5
26 changed files with 566 additions and 550 deletions

View File

@ -35,18 +35,18 @@ class OpBuilder;
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc"
namespace xla_lhlo { namespace lmhlo {
class XlaLhloDialect : public Dialect { class LmhloDialect : public Dialect {
public: public:
explicit XlaLhloDialect(MLIRContext *context); explicit LmhloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "xla_lhlo"; } static StringRef getDialectNamespace() { return "lmhlo"; }
}; };
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
} // namespace xla_lhlo } // namespace lmhlo
} // end namespace mlir } // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_

View File

@ -38,8 +38,8 @@ include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInte
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
def LHLO_Dialect : Dialect { def LHLO_Dialect : Dialect {
let name = "xla_lhlo"; let name = "lmhlo";
let cppNamespace = "xla_lhlo"; let cppNamespace = "lmhlo";
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -253,7 +253,7 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example, // TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example,
// A tuple-like pattern match syntax could work: // A tuple-like pattern match syntax could work:
// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) { // lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
// ... // ...
// }, { // }, {
// ... // ...
@ -337,7 +337,7 @@ def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
Example: Example:
```mlir ```mlir
%buf_transformed = %buf_transformed =
xla_lhlo.static_memref_cast %buf lmhlo.static_memref_cast %buf
: memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]> : memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
// The result of the op is a rank-1 memref with `[5]` shape, stride 1 and // The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
@ -379,7 +379,7 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
Example: Example:
```mlir ```mlir
%buf_transformed = %buf_transformed =
xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y] lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]> : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
// The result of the op is a type-erased memref with `[%size_X, %size_Y]` // The result of the op is a type-erased memref with `[%size_X, %size_Y]`
// shape and `[%step_X, %step_Y]` strides. The offset will be inherited // shape and `[%step_X, %step_Y]` strides. The offset will be inherited

View File

@ -34,7 +34,7 @@ using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
#define MAP_HLO_TO_LHLO(OpName) \ #define MAP_HLO_TO_LHLO(OpName) \
template <> \ template <> \
struct HloToLhloOpImpl<mhlo::OpName> { \ struct HloToLhloOpImpl<mhlo::OpName> { \
using Type = xla_lhlo::OpName; \ using Type = lmhlo::OpName; \
} }
MAP_HLO_TO_LHLO(AbsOp); MAP_HLO_TO_LHLO(AbsOp);

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
namespace mlir { namespace mlir {
namespace xla_lhlo { namespace lmhlo {
namespace impl { namespace impl {
// A struct to map LhloBinaryOpTy type to the corresponding floating-point and // A struct to map LhloBinaryOpTy type to the corresponding floating-point and
@ -33,32 +33,32 @@ template <typename LhloBinaryOpTy>
struct LhloToScalarOp; struct LhloToScalarOp;
template <> template <>
struct LhloToScalarOp<xla_lhlo::AddOp> { struct LhloToScalarOp<lmhlo::AddOp> {
using FOp = ::mlir::AddFOp; using FOp = ::mlir::AddFOp;
using IOp = ::mlir::AddIOp; using IOp = ::mlir::AddIOp;
}; };
template <> template <>
struct LhloToScalarOp<xla_lhlo::CompareOp> { struct LhloToScalarOp<lmhlo::CompareOp> {
using FOp = ::mlir::CmpFOp; using FOp = ::mlir::CmpFOp;
using IOp = ::mlir::CmpIOp; using IOp = ::mlir::CmpIOp;
}; };
template <> template <>
struct LhloToScalarOp<xla_lhlo::DivOp> { struct LhloToScalarOp<lmhlo::DivOp> {
using FOp = ::mlir::DivFOp; using FOp = ::mlir::DivFOp;
using IOp = ::mlir::SignedDivIOp; using IOp = ::mlir::SignedDivIOp;
}; };
template <> template <>
struct LhloToScalarOp<xla_lhlo::MulOp> { struct LhloToScalarOp<lmhlo::MulOp> {
using FOp = ::mlir::MulFOp; using FOp = ::mlir::MulFOp;
using IOp = ::mlir::MulIOp; using IOp = ::mlir::MulIOp;
}; };
template <> template <>
struct LhloToScalarOp<xla_lhlo::RemOp> { struct LhloToScalarOp<lmhlo::RemOp> {
using FOp = ::mlir::RemFOp; using FOp = ::mlir::RemFOp;
using IOp = ::mlir::SignedRemIOp; using IOp = ::mlir::SignedRemIOp;
}; };
template <> template <>
struct LhloToScalarOp<xla_lhlo::SubOp> { struct LhloToScalarOp<lmhlo::SubOp> {
using FOp = ::mlir::SubFOp; using FOp = ::mlir::SubFOp;
using IOp = ::mlir::SubIOp; using IOp = ::mlir::SubIOp;
}; };
@ -116,8 +116,9 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) { if (element_type.isa<FloatType>()) {
@ -125,7 +126,7 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
loc, result_types, args, b); loc, result_types, args, b);
} }
if (element_type.isa<IntegerType>()) { if (element_type.isa<IntegerType>()) {
// xla_lhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) // lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
Value lhs = args[0]; Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>(); auto integer_type = element_type.dyn_cast<IntegerType>();
@ -133,15 +134,16 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge, auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
lhs, zero_intval); lhs, zero_intval);
auto neg_val = b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs); auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val); return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val);
} }
return nullptr; return nullptr;
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}( return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
@ -205,30 +207,33 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc,
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return args.front(); return args.front();
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args, return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args,
@ -236,21 +241,23 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RealOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b); return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b);
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ImagOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b); return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b);
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type sourceType = args.front().getType(); Type sourceType = args.front().getType();
@ -288,8 +295,9 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
// Dot Op converter from lhlo to affine only accepts float and integer types. // Dot Op converter from lhlo to affine only accepts float and integer types.
const auto& lhs = args[0]; const auto& lhs = args[0];
@ -312,16 +320,18 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SinOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
@ -361,38 +371,40 @@ struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
}; };
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp< return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType, IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types,
result_types, args, args, b);
b);
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp< return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType, IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types,
result_types, args, args, b);
b);
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) { if (element_type.isa<FloatType>()) {
@ -400,27 +412,28 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
loc, result_types, args, b); loc, result_types, args, b);
} }
if (element_type.isa<IntegerType>()) { if (element_type.isa<IntegerType>()) {
// xla_lhlo.neg(x, result) -> result = sub(0, x) // lmhlo.neg(x, result) -> result = sub(0, x)
Value lhs = args[0]; Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>(); auto integer_type = element_type.dyn_cast<IntegerType>();
auto zero_intval = auto zero_intval =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
return b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs); return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
} }
return nullptr; return nullptr;
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RsqrtOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
@ -428,8 +441,9 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) { if (element_type.isa<FloatType>()) {
@ -442,16 +456,18 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SqrtOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
} }
template <> template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>( inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
loc, result_types, args, b); loc, result_types, args, b);
@ -460,10 +476,10 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
} // namespace impl } // namespace impl
struct XlaOpToStdScalarOp { struct XlaOpToStdScalarOp {
// Implementation for LHLO ops except xla_lhlo::CompareOp. // Implementation for LHLO ops except lmhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy, template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
typename = std::enable_if_t< typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value && !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>, std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>> std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types, static Value map(XlaOpTy op, ArrayRef<Type> result_types,
@ -475,7 +491,7 @@ struct XlaOpToStdScalarOp {
// Implementation for HLO ops except mhlo::CompareOp. // Implementation for HLO ops except mhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>, template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
typename = std::enable_if_t< typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value && !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, std::false_type>::value>> !std::is_same<LhloOpTy, std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types, static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, int i = 0) { ArrayRef<Value> args, OpBuilder* b, int i = 0) {
@ -483,13 +499,13 @@ struct XlaOpToStdScalarOp {
args, b); args, b);
} }
// Implementation for xla_lhlo::CompareOp. // Implementation for lmhlo::CompareOp.
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same< template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
LhloOpTy, xla_lhlo::CompareOp>::value>> LhloOpTy, lmhlo::CompareOp>::value>>
static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types, static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction(); auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>( return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b); op.getLoc(), comparison_direction, result_types, args, b);
} }
@ -500,12 +516,12 @@ struct XlaOpToStdScalarOp {
static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types, static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction(); auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>( return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b); op.getLoc(), comparison_direction, result_types, args, b);
} }
}; };
} // namespace xla_lhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_

View File

@ -60,7 +60,7 @@ std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
} // namespace mhlo } // namespace mhlo
namespace xla_lhlo { namespace lmhlo {
// Lowers from LHLO dialect to Affine dialect. // Lowers from LHLO dialect to Affine dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass(); std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass();
@ -92,7 +92,7 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();
// Lowers from LHLO dialect to parallel loops. // Lowers from LHLO dialect to parallel loops.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass(); std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
} // namespace xla_lhlo } // namespace lmhlo
namespace xla { namespace xla {

View File

@ -75,14 +75,14 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
} // namespace mhlo } // namespace mhlo
namespace xla_lhlo { namespace lmhlo {
/// Collect a set of patterns to convert from the LHLO dialect to LLVM. /// Collect a set of patterns to convert from the LHLO dialect to LLVM.
void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
LLVMTypeConverter *converter, LLVMTypeConverter *converter,
OwningRewritePatternList *patterns); OwningRewritePatternList *patterns);
} // namespace xla_lhlo } // namespace lmhlo
namespace xla_chlo { namespace xla_chlo {

View File

@ -21,4 +21,4 @@ limitations under the License.
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops; static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect> static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
xla_chlo_ops; xla_chlo_ops;
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops; static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops;

View File

@ -46,9 +46,9 @@ limitations under the License.
namespace mlir { namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc"
namespace xla_lhlo { namespace lmhlo {
XlaLhloDialect::XlaLhloDialect(MLIRContext *context) LmhloDialect::LmhloDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) { : Dialect(getDialectNamespace(), context) {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
@ -138,5 +138,5 @@ void FusionOp::build(OpBuilder &builder, OperationState &result,
FusionOp::ensureTerminator(*bodyRegion, builder, result.location); FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
} }
} // namespace xla_lhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -44,7 +44,7 @@ template <typename T>
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>; using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
using StdReturnOpConverter = using StdReturnOpConverter =
detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp, detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
xla_lhlo::CopyOp, true>; lmhlo::CopyOp, true>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand, Value shape_operand,
@ -149,7 +149,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
Value transformed_operand = Value transformed_operand =
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
rewriter.create<xla_lhlo::BroadcastInDimOp>( rewriter.create<lmhlo::BroadcastInDimOp>(
loc, transformed_operand, resultBuffer, op.broadcast_dimensions()); loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
rewriter.replaceOp(op, {resultBuffer}); rewriter.replaceOp(op, {resultBuffer});
@ -161,7 +161,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
// Inserts dynamic memref to change the layout of the memref to put 0-stride // Inserts dynamic memref to change the layout of the memref to put 0-stride
// and size of the target dimension if size-1 dimension expansion is // and size of the target dimension if size-1 dimension expansion is
// necessary. // necessary.
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp( lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const { mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc(); auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>(); auto operand_type = operand.getType().cast<MemRefType>();
@ -214,7 +214,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
makeStridedLinearLayoutMap(dynamic_layout, makeStridedLinearLayoutMap(dynamic_layout,
/*offset=*/0, b->getContext())); /*offset=*/0, b->getContext()));
auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>( auto transformed_operand = b->create<lmhlo::DynamicMemRefCastOp>(
loc, type_erased_memref_type, operand, sizes, strides); loc, type_erased_memref_type, operand, sizes, strides);
return transformed_operand; return transformed_operand;
} }
@ -239,7 +239,7 @@ struct HloToLhloDynamicReshapeConverter
return failure(); return failure();
} }
mhlo::DynamicReshapeOp::Adaptor adaptor(operands); mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<xla_lhlo::ReshapeMemRefCastOp>( rewriter.replaceOpWithNewOp<lmhlo::ReshapeMemRefCastOp>(
op, result_type, adaptor.operand(), adaptor.output_shape()); op, result_type, adaptor.operand(), adaptor.output_shape());
return success(); return success();
} }
@ -266,8 +266,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
buffer_args.push_back( buffer_args.push_back(
InsertAlloc(loc, result, this->bufferAssignment, &rewriter)); InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
} }
auto new_op = rewriter.create<xla_lhlo::ReduceOp>( auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
loc, llvm::None, buffer_args, op.getAttrs()); op.getAttrs());
// Copy over the operations inside the region. // Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end()); rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
@ -292,7 +292,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
} }
// Insert terminator at the end. // Insert terminator at the end.
rewriter.setInsertionPointToEnd(&entry_block); rewriter.setInsertionPointToEnd(&entry_block);
rewriter.create<xla_lhlo::TerminatorOp>(loc); rewriter.create<lmhlo::TerminatorOp>(loc);
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size())); rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
@ -321,8 +321,8 @@ class HloToLhloTensorStoreOpConverter
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
mlir::TensorStoreOp op, ArrayRef<Value> operands, mlir::TensorStoreOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>( rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
op, llvm::None, operands.front(), operands.back()); operands.back());
return success(); return success();
} }
}; };
@ -336,7 +336,7 @@ class HloToLhloTensorStoreOpConverter
// %arg1: memref<2x2xf32>, // %arg1: memref<2x2xf32>,
// %arg2: memref<2x2xf32>, // %arg2: memref<2x2xf32>,
// %arg3: memref<2x2xf32>) { // %arg3: memref<2x2xf32>) {
// "xla_lhlo.fusion"() ({ // "lmhlo.fusion"() ({
// %0 = tensor_load %arg1 : memref<2x2xf32> // %0 = tensor_load %arg1 : memref<2x2xf32>
// %1 = tensor_load %arg2 : memref<2x2xf32> // %1 = tensor_load %arg2 : memref<2x2xf32>
// %2 = "mhlo.add"(%0, %1) : // %2 = "mhlo.add"(%0, %1) :
@ -345,7 +345,7 @@ class HloToLhloTensorStoreOpConverter
// %4 = "mhlo.multiply"(%2, %3) : // %4 = "mhlo.multiply"(%2, %3) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// tensor_store %4, %arg3 : memref<2x2xf32> // tensor_store %4, %arg3 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> () // "lmhlo.terminator"() : () -> ()
// }) : () -> () // }) : () -> ()
// return // return
// } // }
@ -355,13 +355,13 @@ class HloToLhloTensorStoreOpConverter
// %arg1: memref<2x2xf32>, // %arg1: memref<2x2xf32>,
// %arg2: memref<2x2xf32>, // %arg2: memref<2x2xf32>,
// %arg3: memref<2x2xf32>) { // %arg3: memref<2x2xf32>) {
// "xla_lhlo.fusion"() ( { // "lmhlo.fusion"() ( {
// %0 = alloc() : memref<2x2xf32> // %0 = alloc() : memref<2x2xf32>
// "xla_lhlo.add"(%arg1, %arg2, %0) : // "lmhlo.add"(%arg1, %arg2, %0) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.multiply"(%0, %arg0, %arg3) : // "lmhlo.multiply"(%0, %arg0, %arg3) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.terminator"() : () -> () // "lmhlo.terminator"() : () -> ()
// }) : () -> () // }) : () -> ()
// return // return
// } // }
@ -382,13 +382,13 @@ class HloToLhloTensorStoreOpConverter
// %arg2: memref<4xf32>) { // %arg2: memref<4xf32>) {
// %0 = alloc() : memref<4xf32> // %0 = alloc() : memref<4xf32>
// "xla_lhlo.maximum"(%arg0, %arg1, %0) : // "lmhlo.maximum"(%arg0, %arg1, %0) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// %1 = alloc() : memref<4xf32> // %1 = alloc() : memref<4xf32>
// "xla_lhlo.add"(%arg0, %0, %1) : // "lmhlo.add"(%arg0, %0, %1) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () // "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.terminator"() : () -> () // "lmhlo.terminator"() : () -> ()
// } // }
struct HloLegalizeToLhlo struct HloLegalizeToLhlo
@ -406,7 +406,7 @@ struct HloLegalizeToLhlo
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto& context = getContext(); auto& context = getContext();
ConversionTarget target(context); ConversionTarget target(context);
target.addLegalDialect<xla_lhlo::XlaLhloDialect>(); target.addLegalDialect<lmhlo::LmhloDialect>();
target.addLegalDialect<StandardOpsDialect>(); target.addLegalDialect<StandardOpsDialect>();
target.addLegalOp<ModuleOp>(); target.addLegalOp<ModuleOp>();
target.addIllegalOp<mlir::TensorLoadOp>(); target.addIllegalOp<mlir::TensorLoadOp>();
@ -441,12 +441,12 @@ struct HloLegalizeToLhlo
&converter, &patterns); &converter, &patterns);
if (results_escape_function) { if (results_escape_function) {
populateWithBufferAssignmentOpConversionPatterns< populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp, mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp,
/*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment, /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
&converter, &patterns); &converter, &patterns);
} else { } else {
populateWithBufferAssignmentOpConversionPatterns< populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp, mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp,
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment, /*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
&converter, &patterns); &converter, &patterns);
} }

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir { namespace mlir {
namespace xla_lhlo { namespace lmhlo {
namespace { namespace {
// Removes LHLO copy operations that copy from allocated buffers to block // Removes LHLO copy operations that copy from allocated buffers to block
@ -34,7 +34,7 @@ struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
void runOnOperation() override { void runOnOperation() override {
llvm::SmallVector<mlir::Operation*, 2> eraseList; llvm::SmallVector<mlir::Operation*, 2> eraseList;
auto operation = getOperation(); auto operation = getOperation();
operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) { operation->walk([&](mlir::lmhlo::CopyOp copyOp) {
// If this region contains more than one block, then ignore this copy // If this region contains more than one block, then ignore this copy
// operation. // operation.
if (copyOp.getParentRegion()->getBlocks().size() > 1) { if (copyOp.getParentRegion()->getBlocks().size() > 1) {
@ -101,5 +101,5 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass() {
static PassRegistration<LhloCopyRemoval> copy_removal_pass( static PassRegistration<LhloCopyRemoval> copy_removal_pass(
"lhlo-copy-removal", "Removes redundant LHLO copy operations"); "lhlo-copy-removal", "Removes redundant LHLO copy operations");
} // namespace xla_lhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir { namespace mlir {
namespace xla_lhlo { namespace lmhlo {
namespace { namespace {
using linalg::LinalgOp; using linalg::LinalgOp;
@ -147,5 +147,5 @@ static PassRegistration<LhloFuseLinalg> legalize_pass(
"lhlo-fuse-linalg", "lhlo-fuse-linalg",
"Greedily fuse linalg ops obtained after LHLO lowering."); "Greedily fuse linalg ops obtained after LHLO lowering.");
} // namespace xla_lhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
namespace mlir { namespace mlir {
namespace xla_lhlo { namespace lmhlo {
namespace { namespace {
// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit // Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices); auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
auto result = auto result =
rewriter.create<AffineLoadOp>(loc, op.output(), result_indices); rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<DotOp>( Value op_result = lmhlo::XlaOpToStdScalarOp::map<DotOp>(
op, element_type, {l, r, result}, &builder); op, element_type, {l, r, result}, &builder);
map_status = success(op_result != nullptr); map_status = success(op_result != nullptr);
if (failed(map_status)) return; if (failed(map_status)) return;
@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
ValueRange induction_vars) { ValueRange induction_vars) {
auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars); auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars); auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>( Value op_result = lmhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
op, element_type, {l, r}, &builder); op, element_type, {l, r}, &builder);
map_status = success(op_result != nullptr); map_status = success(op_result != nullptr);
if (failed(map_status)) return; if (failed(map_status)) return;
@ -127,13 +127,13 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
// clang-format off // clang-format off
patterns->insert< patterns->insert<
BinaryOpConverter<xla_lhlo::AddOp>, BinaryOpConverter<lmhlo::AddOp>,
BinaryOpConverter<xla_lhlo::AndOp>, BinaryOpConverter<lmhlo::AndOp>,
BinaryOpConverter<xla_lhlo::DivOp>, BinaryOpConverter<lmhlo::DivOp>,
BinaryOpConverter<xla_lhlo::MaxOp>, BinaryOpConverter<lmhlo::MaxOp>,
BinaryOpConverter<xla_lhlo::MinOp>, BinaryOpConverter<lmhlo::MinOp>,
BinaryOpConverter<xla_lhlo::MulOp>, BinaryOpConverter<lmhlo::MulOp>,
BinaryOpConverter<xla_lhlo::SubOp>, BinaryOpConverter<lmhlo::SubOp>,
DotOpConverter>(context); DotOpConverter>(context);
// clang-format on // clang-format on
} }
@ -157,5 +157,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() {
static PassRegistration<LhloLegalizeToAffine> legalize_pass( static PassRegistration<LhloLegalizeToAffine> legalize_pass(
"lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect"); "lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect");
} // namespace xla_lhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -38,7 +38,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
namespace mlir { namespace mlir {
namespace xla_lhlo { namespace lmhlo {
namespace { namespace {
// A simple translation of LHLO reduce operations to a corresponding gpu // A simple translation of LHLO reduce operations to a corresponding gpu
@ -173,7 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect, target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>(); gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
target.addIllegalOp<ReduceOp>(); target.addIllegalOp<ReduceOp>();
auto func = getFunction(); auto func = getFunction();
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext()); patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
@ -192,5 +192,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() {
static PassRegistration<LhloLegalizeToGpu> legalize_pass( static PassRegistration<LhloLegalizeToGpu> legalize_pass(
"lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect"); "lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect");
} // namespace xla_lhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir { namespace mlir {
namespace xla_lhlo { namespace lmhlo {
namespace { namespace {
struct StaticMemRefCastOpConverter struct StaticMemRefCastOpConverter
@ -132,5 +132,5 @@ void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
*converter, options); *converter, options);
} }
} // namespace xla_lhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir { namespace mlir {
namespace xla_lhlo { namespace lmhlo {
namespace { namespace {
class TestLhloToLLVMPass class TestLhloToLLVMPass
@ -42,7 +42,7 @@ class TestLhloToLLVMPass
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>(); target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalDialect<XlaLhloDialect>(); target.addIllegalDialect<LmhloDialect>();
if (failed(applyFullConversion(m, target, patterns))) { if (failed(applyFullConversion(m, target, patterns))) {
signalPassFailure(); signalPassFailure();
@ -55,5 +55,5 @@ class TestLhloToLLVMPass
static PassRegistration<TestLhloToLLVMPass> legalize_lhlo_pass( static PassRegistration<TestLhloToLLVMPass> legalize_lhlo_pass(
"test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM."); "test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM.");
} // namespace xla_lhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir { namespace mlir {
namespace xla_lhlo { namespace lmhlo {
namespace { namespace {
// Clones and adapts the code in `lhlo_block` that works on buffers and has a // Clones and adapts the code in `lhlo_block` that works on buffers and has a
@ -154,14 +154,14 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
return b->create<scf::ParallelOp>(loc, lower, upper, step); return b->create<scf::ParallelOp>(loc, lower, upper, step);
} }
// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp. // Converts `lmhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// The outper `ParallelOp` refers to the parallel loops if there are // The outper `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp` // any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp`
// contains the reduction operator. // contains the reduction operator.
// //
// Example: // Example:
// //
// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( { // "lmhlo.reduce"(%buffer, %init_buf, %result) ( {
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): // ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// <LHLO ops> // <LHLO ops>
// } ) {dimensions = dense<[1]> : tensor<1xi64>} // } ) {dimensions = dense<[1]> : tensor<1xi64>}
@ -187,12 +187,12 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
// } : f32 // } : f32
// scf.yield // scf.yield
// } // }
class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> { class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
public: public:
using OpConversionPattern<xla_lhlo::ReduceOp>::OpConversionPattern; using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/, lmhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
// TODO(b/137624192) Implement variadic reduce. // TODO(b/137624192) Implement variadic reduce.
if (xla_reduce_op.out().size() != 1) return failure(); if (xla_reduce_op.out().size() != 1) return failure();
@ -226,7 +226,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// scf.yield // scf.yield
// } // }
scf::ReduceOp CreateReduceOpInNestedParallelLoops( scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceOp xla_reduce_op, lmhlo::ReduceOp xla_reduce_op,
ConversionPatternRewriter* rewriter) const { ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_op.getLoc(); auto loc = xla_reduce_op.getLoc();
DenseSet<int> reducing_dims; DenseSet<int> reducing_dims;
@ -314,7 +314,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// accumulator = reduction_operator(output[O], value) // accumulator = reduction_operator(output[O], value)
// output[O] = accumulator // output[O] = accumulator
// //
// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a // Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a
// scf::ReduceOp. // scf::ReduceOp.
// The outper `ParallelOp` refers to the parallel loops that traverese output // The outper `ParallelOp` refers to the parallel loops that traverese output
// buffer. The inner `ParalleOp` refers to the reduction loops that traverse // buffer. The inner `ParalleOp` refers to the reduction loops that traverse
@ -325,11 +325,11 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// func @reduce_window(%arg: memref<112x112xf32>, // func @reduce_window(%arg: memref<112x112xf32>,
// %init: memref<f32>, // %init: memref<f32>,
// %result: memref<56x56xf32>) { // %result: memref<56x56xf32>) {
// "xla_lhlo.reduce_window"(%arg, %init, %result) ( { // "lmhlo.reduce_window"(%arg, %init, %result) ( {
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): // ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// "xla_lhlo.maximum"(%lhs, %rhs, %res) // "lmhlo.maximum"(%lhs, %rhs, %res)
// : (memref<f32>, memref<f32>, memref<f32>) -> () // : (memref<f32>, memref<f32>, memref<f32>) -> ()
// "xla_lhlo.terminator"() : () -> () // "lmhlo.terminator"() : () -> ()
// }) { // }) {
// padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, // padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
// window_dimensions = dense<[3, 3]> : tensor<2xi64>, // window_dimensions = dense<[3, 3]> : tensor<2xi64>,
@ -359,12 +359,12 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// return // return
// } // }
class ReduceWindowOpConverter class ReduceWindowOpConverter
: public OpConversionPattern<xla_lhlo::ReduceWindowOp> { : public OpConversionPattern<lmhlo::ReduceWindowOp> {
public: public:
using OpConversionPattern<xla_lhlo::ReduceWindowOp>::OpConversionPattern; using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/, lmhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
scf::ParallelOp output_loop, window_loop; scf::ParallelOp output_loop, window_loop;
std::tie(output_loop, window_loop) = std::tie(output_loop, window_loop) =
@ -383,7 +383,7 @@ class ReduceWindowOpConverter
private: private:
std::pair<scf::ParallelOp, scf::ParallelOp> std::pair<scf::ParallelOp, scf::ParallelOp>
CreateParallelLoopsToTraverseOutputAndWindow( CreateParallelLoopsToTraverseOutputAndWindow(
xla_lhlo::ReduceWindowOp xla_reduce_window_op, lmhlo::ReduceWindowOp xla_reduce_window_op,
ConversionPatternRewriter* rewriter) const { ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_window_op.getLoc(); auto loc = xla_reduce_window_op.getLoc();
Value init_value = Value init_value =
@ -415,9 +415,8 @@ class ReduceWindowOpConverter
} }
scf::ReduceOp CreateReduceOpInNestedParallelLoops( scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceWindowOp xla_reduce_window_op, lmhlo::ReduceWindowOp xla_reduce_window_op, scf::ParallelOp output_loop,
scf::ParallelOp output_loop, scf::ParallelOp window_loop, scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
ConversionPatternRewriter* rewriter) const {
rewriter->setInsertionPointToStart(window_loop.getBody()); rewriter->setInsertionPointToStart(window_loop.getBody());
auto loc = xla_reduce_window_op.getLoc(); auto loc = xla_reduce_window_op.getLoc();
@ -481,12 +480,12 @@ class ReduceWindowOpConverter
// initialized_flag = true // initialized_flag = true
// output(selected_index) = scatter(output(selected_index), source(S)) // output(selected_index) = scatter(output(selected_index), source(S))
class SelectAndScatterOpConverter class SelectAndScatterOpConverter
: public OpConversionPattern<xla_lhlo::SelectAndScatterOp> { : public OpConversionPattern<lmhlo::SelectAndScatterOp> {
public: public:
using OpConversionPattern<xla_lhlo::SelectAndScatterOp>::OpConversionPattern; using OpConversionPattern<lmhlo::SelectAndScatterOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/, lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = s_and_s_op.getLoc(); auto loc = s_and_s_op.getLoc();
InitializeOutput(s_and_s_op, &rewriter); InitializeOutput(s_and_s_op, &rewriter);
@ -515,7 +514,7 @@ class SelectAndScatterOpConverter
} }
private: private:
void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op, void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
OpBuilder* b) const { OpBuilder* b) const {
auto loc = s_and_s_op.getLoc(); auto loc = s_and_s_op.getLoc();
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value()); Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
@ -533,7 +532,7 @@ class SelectAndScatterOpConverter
SmallVector<Value, 2> window_ivs; SmallVector<Value, 2> window_ivs;
scf::ForOp inner_loop; scf::ForOp inner_loop;
}; };
WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op, WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op,
scf::ParallelOp loop_over_src, scf::ParallelOp loop_over_src,
OpBuilder* b) const { OpBuilder* b) const {
auto loc = s_and_s_op.getLoc(); auto loc = s_and_s_op.getLoc();
@ -598,7 +597,7 @@ class SelectAndScatterOpConverter
SmallVector<Value, 4> ivs_val_flag_; SmallVector<Value, 4> ivs_val_flag_;
}; };
SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op, SmallVector<Value, 2> SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op,
scf::ParallelOp loop_over_src, scf::ParallelOp loop_over_src,
OpBuilder* b) const { OpBuilder* b) const {
auto loc = s_and_s_op.getLoc(); auto loc = s_and_s_op.getLoc();
@ -636,9 +635,10 @@ class SelectAndScatterOpConverter
return window_loops.selected_ivs; return window_loops.selected_ivs;
} }
SmallVector<Value, 4> SelectOrInitialize( SmallVector<Value, 4> SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op,
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> operand_ivs, ArrayRef<Value> operand_ivs,
IterArgs* ivs_val_flag, OpBuilder* b) const { IterArgs* ivs_val_flag,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc(); auto loc = s_and_s_op.getLoc();
Value true_i1 = b->create<mlir::ConstantOp>( Value true_i1 = b->create<mlir::ConstantOp>(
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1)); loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
@ -707,9 +707,9 @@ struct LhloLegalizeToParallelLoops
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect, target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
scf::SCFDialect, XlaLhloDialect>(); scf::SCFDialect, LmhloDialect>();
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp, target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
xla_lhlo::SelectAndScatterOp>(); lmhlo::SelectAndScatterOp>();
if (failed(applyPartialConversion(func, target, patterns))) { if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure(); signalPassFailure();
@ -727,5 +727,5 @@ static PassRegistration<LhloLegalizeToParallelLoops> legalize_lhlo_pass(
"lhlo-legalize-to-parallel-loops", "lhlo-legalize-to-parallel-loops",
"Legalize from LHLO dialect to parallel loops."); "Legalize from LHLO dialect to parallel loops.");
} // namespace xla_lhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -131,9 +131,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
loc, opResultTypes, args, args_count, results_count, indexing_maps, loc, opResultTypes, args, args_count, results_count, indexing_maps,
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace. // TODO(ravishankarm) : For now use the method in lmhlo namespace.
// That method needs to be moved out of there. // That method needs to be moved out of there.
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>( Value opResult = lmhlo::XlaOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes, op, bodyResultTypes,
llvm::to_vector<2>(args.take_front(args_count)), &rewriter); llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, opResult); nestedBuilder.create<linalg::YieldOp>(loc, opResult);
@ -162,8 +162,8 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
// Create two loads from the input. // Create two loads from the input.
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs()); auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs()); auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace. // TODO(ravishankarm) : Move this method out of lmhlo namespace.
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>( Value opResult = lmhlo::XlaOpToStdScalarOp::map<LhloOp>(
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs}, lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter); &rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out()); rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
@ -173,21 +173,21 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// xla_lhlo.convolution conversion pattern. // lmhlo.convolution conversion pattern.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Converts xla_lhlo.convolution operation to a linalg.conv op. /// Converts lmhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> { struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
public: public:
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern; using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;
// This code has been adapted from IREE's // This code has been adapted from IREE's
// (https://github.com/google/iree/) mhlo -> linalg conversion. // (https://github.com/google/iree/) mhlo -> linalg conversion.
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
xla_lhlo::ConvOp op, ArrayRef<Value> args, lmhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information. // Check validity of dimension information.
if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers = if (const lmhlo::ConvDimensionNumbers& dimensionNumbers =
op.dimension_numbers()) { op.dimension_numbers()) {
const int inputSpatialRank = const int inputSpatialRank =
llvm::size(dimensionNumbers.input_spatial_dimensions()); llvm::size(dimensionNumbers.input_spatial_dimensions());
@ -388,14 +388,14 @@ class HloBroadcastInDimConverter
}; };
class LhloBroadcastInDimConverter class LhloBroadcastInDimConverter
: public OpConversionPattern<xla_lhlo::BroadcastInDimOp> { : public OpConversionPattern<lmhlo::BroadcastInDimOp> {
public: public:
using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern; using OpConversionPattern<lmhlo::BroadcastInDimOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args, lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
auto result_type = operand_adaptor.output().getType().cast<MemRefType>(); auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
auto result_shape = result_type.getShape(); auto result_shape = result_type.getShape();
@ -444,9 +444,9 @@ class LhloBroadcastInDimConverter
// Inserts 'linalg.reshape' if there is a size-1 dim expansion. // Inserts 'linalg.reshape' if there is a size-1 dim expansion.
std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary( std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args, lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const { ConversionPatternRewriter& rewriter) const {
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
Value operand = operand_adaptor.operand(); Value operand = operand_adaptor.operand();
auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>(); auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape(); auto operand_shape = operand_type.getShape();
@ -512,7 +512,7 @@ class LhloBroadcastInDimConverter
return std::make_pair(operand, broadcast_dims); return std::make_pair(operand, broadcast_dims);
} }
SmallVector<AffineMap, 2> getIndexingMaps(xla_lhlo::BroadcastInDimOp op, SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
ArrayRef<int64_t> broadcastDims, ArrayRef<int64_t> broadcastDims,
ArrayRef<int64_t> resultShape, ArrayRef<int64_t> resultShape,
MemRefType operandType, MemRefType operandType,
@ -639,12 +639,12 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
} }
}; };
class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> { class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
public: public:
using OpConversionPattern<xla_lhlo::IotaOp>::OpConversionPattern; using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
xla_lhlo::IotaOp iotaOp, ArrayRef<Value> args, lmhlo::IotaOp iotaOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto resultMemrefType = auto resultMemrefType =
iotaOp.getOperand().getType().dyn_cast<MemRefType>(); iotaOp.getOperand().getType().dyn_cast<MemRefType>();
@ -680,12 +680,12 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
} }
}; };
class ConstConverter : public OpConversionPattern<xla_lhlo::ConstOp> { class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
public: public:
using OpConversionPattern<xla_lhlo::ConstOp>::OpConversionPattern; using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
xla_lhlo::ConstOp constOp, ArrayRef<Value> args, lmhlo::ConstOp constOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = constOp.getLoc(); auto loc = constOp.getLoc();
auto valueAttr = constOp.value().cast<DenseElementsAttr>(); auto valueAttr = constOp.value().cast<DenseElementsAttr>();
@ -726,12 +726,12 @@ class ReverseConverter
} }
}; };
class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> { class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
public: public:
using OpConversionPattern<xla_lhlo::SliceOp>::OpConversionPattern; using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
xla_lhlo::SliceOp sliceOp, ArrayRef<Value> args, lmhlo::SliceOp sliceOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = sliceOp.getLoc(); auto loc = sliceOp.getLoc();
auto argType = auto argType =
@ -763,50 +763,50 @@ class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
void populateLHLOToLinalgConversionPattern(MLIRContext* context, void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
// clang-format off // clang-format off
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>, patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
ConstConverter, ConstConverter,
ConvToLinalgConverter, ConvToLinalgConverter,
IotaConverter, IotaConverter,
LhloBroadcastInDimConverter, LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_lhlo::AbsOp>, PointwiseToLinalgConverter<lmhlo::AbsOp>,
PointwiseToLinalgConverter<xla_lhlo::AddOp>, PointwiseToLinalgConverter<lmhlo::AddOp>,
PointwiseToLinalgConverter<xla_lhlo::AndOp>, PointwiseToLinalgConverter<lmhlo::AndOp>,
PointwiseToLinalgConverter<xla_lhlo::CeilOp>, PointwiseToLinalgConverter<lmhlo::CeilOp>,
PointwiseToLinalgConverter<xla_lhlo::CompareOp>, PointwiseToLinalgConverter<lmhlo::CompareOp>,
PointwiseToLinalgConverter<xla_lhlo::ComplexOp>, PointwiseToLinalgConverter<lmhlo::ComplexOp>,
PointwiseToLinalgConverter<xla_lhlo::ConvertOp>, PointwiseToLinalgConverter<lmhlo::ConvertOp>,
// TODO(ataei): Remove this pattern, CopyOp is folded away. // TODO(ataei): Remove this pattern, CopyOp is folded away.
PointwiseToLinalgConverter<xla_lhlo::CopyOp>, PointwiseToLinalgConverter<lmhlo::CopyOp>,
PointwiseToLinalgConverter<xla_lhlo::CosOp>, PointwiseToLinalgConverter<lmhlo::CosOp>,
PointwiseToLinalgConverter<xla_lhlo::DivOp>, PointwiseToLinalgConverter<lmhlo::DivOp>,
PointwiseToLinalgConverter<xla_lhlo::ExpOp>, PointwiseToLinalgConverter<lmhlo::ExpOp>,
PointwiseToLinalgConverter<xla_lhlo::ImagOp>, PointwiseToLinalgConverter<lmhlo::ImagOp>,
PointwiseToLinalgConverter<xla_lhlo::LogOp>, PointwiseToLinalgConverter<lmhlo::LogOp>,
PointwiseToLinalgConverter<xla_lhlo::MaxOp>, PointwiseToLinalgConverter<lmhlo::MaxOp>,
PointwiseToLinalgConverter<xla_lhlo::MinOp>, PointwiseToLinalgConverter<lmhlo::MinOp>,
PointwiseToLinalgConverter<xla_lhlo::MulOp>, PointwiseToLinalgConverter<lmhlo::MulOp>,
PointwiseToLinalgConverter<xla_lhlo::NegOp>, PointwiseToLinalgConverter<lmhlo::NegOp>,
PointwiseToLinalgConverter<xla_lhlo::RealOp>, PointwiseToLinalgConverter<lmhlo::RealOp>,
PointwiseToLinalgConverter<xla_lhlo::RemOp>, PointwiseToLinalgConverter<lmhlo::RemOp>,
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>, PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SelectOp>, PointwiseToLinalgConverter<lmhlo::SelectOp>,
PointwiseToLinalgConverter<xla_lhlo::SignOp>, PointwiseToLinalgConverter<lmhlo::SignOp>,
PointwiseToLinalgConverter<xla_lhlo::SinOp>, PointwiseToLinalgConverter<lmhlo::SinOp>,
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>, PointwiseToLinalgConverter<lmhlo::SqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SubOp>, PointwiseToLinalgConverter<lmhlo::SubOp>,
PointwiseToLinalgConverter<xla_lhlo::TanhOp>, PointwiseToLinalgConverter<lmhlo::TanhOp>,
ReshapeOpConverter<xla_lhlo::ReshapeOp>, ReshapeOpConverter<lmhlo::ReshapeOp>,
ReverseConverter<xla_lhlo::ReverseOp>, ReverseConverter<lmhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>, ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
SliceConverter SliceConverter
>(context); >(context);
// clang-format on // clang-format on
} }
// Converts LHLO ops to Linalg generic. // Converts LHLO ops to Linalg generic.
// Sample result for xla_lhlo::AddOp. // Sample result for lmhlo::AddOp.
// //
// "xla_lhlo.add"(%arg1, %arg2, %out) : // "lmhlo.add"(%arg1, %arg2, %out) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// //
// will be converted to // will be converted to
@ -854,14 +854,14 @@ struct HloLegalizeToLinalg
} // namespace } // namespace
namespace xla_lhlo { namespace lmhlo {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() { std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
return absl::make_unique<LhloLegalizeToLinalg>(); return absl::make_unique<LhloLegalizeToLinalg>();
} }
static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass( static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect"); "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
} // namespace xla_lhlo } // namespace lmhlo
namespace mhlo { namespace mhlo {

View File

@ -7,7 +7,7 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_result = "mhlo.exponential"(%tensor_operand) %tensor_result = "mhlo.exponential"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} // BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -18,10 +18,10 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
return %arg0 : tensor<4xf32> return %arg0 : tensor<4xf32>
} }
// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) // PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () // PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// PRE-NEXT: return // PRE-NEXT: return
// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// ESC-NOT: "xla_lhlo.copy" // ESC-NOT: "lmhlo.copy"
// ESC-NEXT: return %[[ARG0]] // ESC-NEXT: return %[[ARG0]]
// ----- // -----
@ -38,20 +38,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) // PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32> // ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32> // BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) // BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) // BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> // BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> // BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) // BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> // BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
//  BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) //  BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> // BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) // BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> // BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () // PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> // PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
// PRE-NEXT: return // PRE-NEXT: return
// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32> // ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
@ -67,14 +67,14 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2) %sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) // BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier) %tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) // BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) // BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
// BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> // BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: return // BOTH-NEXT: return
@ -88,7 +88,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.copy"(%tensor_operand) %tensor_result = "mhlo.copy"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -100,7 +100,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.exponential"(%tensor_operand) %tensor_result = "mhlo.exponential"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -112,7 +112,7 @@ func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.log"(%tensor_operand) %tensor_result = "mhlo.log"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.log"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -127,7 +127,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) %tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) // BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -141,7 +141,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
%tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs) %tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
{comparison_direction = "EQ"} {comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} // BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
tensor_store %tensor_result, %result : memref<2x2xi1> tensor_store %tensor_result, %result : memref<2x2xi1>
return return
} }
@ -154,7 +154,7 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
%tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand) %tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
{broadcast_dimensions = dense<1> : tensor<1xi64>} {broadcast_dimensions = dense<1> : tensor<1xi64>}
: (tensor<5xf32>) -> tensor<10x5xf32> : (tensor<5xf32>) -> tensor<10x5xf32>
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} // BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<10x5xf32> tensor_store %tensor_result, %result : memref<10x5xf32>
return return
} }
@ -205,12 +205,12 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]] // BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index // BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast // BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast
// BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]) // BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]] // BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0> // BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
// BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { // BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> // BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> () // BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
@ -229,7 +229,7 @@ func @complex(%real: memref<2x2xf32>,
%tensor_imag = tensor_load %imag : memref<2x2xf32> %tensor_imag = tensor_load %imag : memref<2x2xf32>
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag) %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>> tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
return return
} }
@ -241,7 +241,7 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>> %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "mhlo.real"(%tensor_operand) %tensor_result = "mhlo.real"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -253,7 +253,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>> %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "mhlo.imag"(%tensor_operand) %tensor_result = "mhlo.imag"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -264,7 +264,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
func @iota(%result: memref<10xi32>) { func @iota(%result: memref<10xi32>) {
%tensor_result = "mhlo.iota"() %tensor_result = "mhlo.iota"()
{iota_dimension = 0 : i64} : () -> tensor<10xi32> {iota_dimension = 0 : i64} : () -> tensor<10xi32>
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} // BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
tensor_store %tensor_result, %result : memref<10xi32> tensor_store %tensor_result, %result : memref<10xi32>
return return
} }
@ -276,7 +276,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.abs"(%tensor_operand) %tensor_result = "mhlo.abs"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -288,7 +288,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.ceil"(%tensor_operand) %tensor_result = "mhlo.ceil"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -300,7 +300,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.convert"(%tensor_operand) %tensor_result = "mhlo.convert"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
// BOTH-NOT: tensor_store // BOTH-NOT: tensor_store
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
@ -313,7 +313,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.cosine"(%tensor_operand) %tensor_result = "mhlo.cosine"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -325,7 +325,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.negate"(%tensor_operand) %tensor_result = "mhlo.negate"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -337,7 +337,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.rsqrt"(%tensor_operand) %tensor_result = "mhlo.rsqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -349,7 +349,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.sign"(%tensor_operand) %tensor_result = "mhlo.sign"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -361,7 +361,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.sqrt"(%tensor_operand) %tensor_result = "mhlo.sqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -373,7 +373,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.tanh"(%tensor_operand) %tensor_result = "mhlo.tanh"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}}) // BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -386,7 +386,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
%tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs) %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) // BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
return return
} }
@ -412,7 +412,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () // BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return return
} }
@ -437,7 +437,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () // BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return return
} }
@ -448,10 +448,10 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) // PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc // BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () // BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "mhlo.dot"(%arg0, %arg0) %dot = "mhlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]]) // PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
// ESC: return %[[ALLOC]] // ESC: return %[[ALLOC]]
return %dot : tensor<1024x1024xf32> return %dot : tensor<1024x1024xf32>
} }
@ -462,7 +462,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> { func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index %c0 = constant 0 : index
// BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32> // BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]]) // BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// BOTH-SAME: padding = dense<[ // BOTH-SAME: padding = dense<[
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64> // BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// BOTH-SAME: rhs_dilation = dense<[1, 2]> // BOTH-SAME: rhs_dilation = dense<[1, 2]>

View File

@ -3,10 +3,10 @@
// CHECK-LABEL: func @remove_simple // CHECK-LABEL: func @remove_simple
func @remove_simple(%arg0: memref<2x2xf32>) { func @remove_simple(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32> %0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32> dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> () // CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }
// ----- // -----
@ -14,9 +14,9 @@ func @remove_simple(%arg0: memref<2x2xf32>) {
// CHECK-LABEL: func @remove_without_dealloc // CHECK-LABEL: func @remove_without_dealloc
func @remove_without_dealloc(%arg0: memref<2x2xf32>) { func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32> %0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> () // CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }
// ----- // -----
@ -24,22 +24,22 @@ func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
// CHECK-LABEL: func @replace_dependency // CHECK-LABEL: func @replace_dependency
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32> %0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32> dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> () // CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }
// ----- // -----
// CHECK-LABEL: func @keep_copies // CHECK-LABEL: func @keep_copies
func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
// CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () // CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> () // CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }
// ----- // -----
@ -50,14 +50,14 @@ func @must_not_be_removed(%arg0: memref<2x2xf32>,
%arg2: memref<2x2xf32>) { %arg2: memref<2x2xf32>) {
// CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32> // CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32>
%0 = alloc() {temp = true} : memref<2x2xf32> %0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () // CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () // CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32> dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }
// ----- // -----
@ -67,13 +67,13 @@ func @must_be_removed_first(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>, %arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) { %arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32> %0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32> dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }
// ----- // -----
@ -83,11 +83,11 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>, %arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) { %arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32> %0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32> dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }

View File

@ -10,18 +10,18 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
%src: memref<56x56xf32>, %src: memref<56x56xf32>,
%init: memref<f32>, %init: memref<f32>,
%result: memref<112x112xf32>) { %result: memref<112x112xf32>) {
"xla_lhlo.select_and_scatter"(%arg, %src, %init, %result) ( { "lmhlo.select_and_scatter"(%arg, %src, %init, %result) ( {
// select // select
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %pred: memref<i1>): ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %pred: memref<i1>):
"xla_lhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} : "lmhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} :
(memref<f32>, memref<f32>, memref<i1>) -> () (memref<f32>, memref<f32>, memref<i1>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
}, { }, {
// scatter // scatter
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>): ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %out) : "lmhlo.add"(%lhs, %rhs, %out) :
(memref<f32>, memref<f32>, memref<f32>) -> () (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
}) { }) {
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
window_dimensions = dense<[3, 3]> : tensor<2xi64>, window_dimensions = dense<[3, 3]> : tensor<2xi64>,
@ -29,7 +29,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
} : (memref<112x112xf32>, } : (memref<112x112xf32>,
memref<56x56xf32>, memref<56x56xf32>,
memref<f32>, memref<112x112xf32>) -> () memref<f32>, memref<112x112xf32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }
// CHECK-LABEL: func @select_and_scatter( // CHECK-LABEL: func @select_and_scatter(
// CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>, // CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>,
@ -121,7 +121,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32> // CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
// Compute PRED. // Compute PRED.
// CHECK: "xla_lhlo.compare"( // CHECK: "lmhlo.compare"(
// CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]]) // CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
// CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1> // CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1>
@ -182,7 +182,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32> // CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
// Compute scatter value. // Compute scatter value.
// CHECK: "xla_lhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) : // CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> () // CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32> // CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>

View File

@ -14,7 +14,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
// CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> // CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
// CHECK: return // CHECK: return
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
(memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> () (memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> ()
return return
} }
@ -24,7 +24,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () { %result: memref<7xf32>) -> () {
// CHECK: addf %{{.*}}, %{{.*}} : f32 // CHECK: addf %{{.*}}, %{{.*}} : f32
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"} "lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return return
} }
@ -32,7 +32,7 @@ func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () { %result: memref<7xi32>) -> () {
// CHECK: addi %{{.*}}, %{{.*}} : i32 // CHECK: addi %{{.*}}, %{{.*}} : i32
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"} "lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return return
} }
@ -42,7 +42,7 @@ func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () { %result: memref<7xi32>) -> () {
// CHECK: and %{{.*}}, %{{.*}} : i32 // CHECK: and %{{.*}}, %{{.*}} : i32
"xla_lhlo.and"(%lhs, %rhs, %result) {name = "and.1"} "lmhlo.and"(%lhs, %rhs, %result) {name = "and.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return return
} }
@ -52,7 +52,7 @@ func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () { %result: memref<7xf32>) -> () {
// CHECK: divf %{{.*}}, %{{.*}} : f32 // CHECK: divf %{{.*}}, %{{.*}} : f32
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} "lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return return
} }
@ -60,7 +60,7 @@ func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () { %result: memref<7xi32>) -> () {
// CHECK: divi_signed %{{.*}}, %{{.*}} : i32 // CHECK: divi_signed %{{.*}}, %{{.*}} : i32
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} "lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return return
} }
@ -71,7 +71,7 @@ func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () { %result: memref<7xf32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], %[[TWO:.*]] : f32 // CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], %[[TWO:.*]] : f32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return return
} }
@ -81,7 +81,7 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () { %result: memref<7xi32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32 // CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return return
} }
@ -92,7 +92,7 @@ func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () { %result: memref<7xf32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpf "olt", %[[ONE:.*]], %[[TWO:.*]] : f32 // CHECK: %[[CHECK:.*]] = cmpf "olt", %[[ONE:.*]], %[[TWO:.*]] : f32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return return
} }
@ -102,7 +102,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () { %result: memref<7xi32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32 // CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return return
} }
@ -112,7 +112,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () { %result: memref<7xf32>) -> () {
// CHECK: mulf %{{.*}}, %{{.*}} : f32 // CHECK: mulf %{{.*}}, %{{.*}} : f32
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} "lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return return
} }
@ -121,7 +121,7 @@ func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () { %result: memref<7xi32>) -> () {
// CHECK: muli %{{.*}}, %{{.*}} : i32 // CHECK: muli %{{.*}}, %{{.*}} : i32
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} "lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return return
} }
@ -131,7 +131,7 @@ func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () { %result: memref<7xf32>) -> () {
// CHECK: subf %{{.*}}, %{{.*}} : f32 // CHECK: subf %{{.*}}, %{{.*}} : f32
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} "lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return return
} }
@ -139,7 +139,7 @@ func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () { %result: memref<7xi32>) -> () {
// CHECK: subi %{{.*}}, %{{.*}} : i32 // CHECK: subi %{{.*}}, %{{.*}} : i32
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} "lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return return
} }
@ -158,7 +158,7 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs:
// CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32 // CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32> // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
// CHECK: return // CHECK: return
"xla_lhlo.dot"(%lhs, %rhs, %result) : "lmhlo.dot"(%lhs, %rhs, %result) :
(memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> () (memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> ()
return return
} }
@ -175,7 +175,7 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
// CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32 // CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32> // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
// CHECK: return // CHECK: return
"xla_lhlo.dot"(%lhs, %rhs, %result) : "lmhlo.dot"(%lhs, %rhs, %result) :
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> () (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
return return
} }

View File

@ -3,11 +3,11 @@
func @reduce(%arg: memref<100x10xf32>, func @reduce(%arg: memref<100x10xf32>,
%init: memref<f32>, %init: memref<f32>,
%result: memref<100xf32>) { %result: memref<100xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( { "lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res) "lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> () : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>} } ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> () : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
return return
@ -25,7 +25,7 @@ func @reduce(%arg: memref<100x10xf32>,
// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { // CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref<f32, #map0> // CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref<f32, #map0>
// CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref<f32, #map0> // CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref<f32, #map0>
// CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> () // CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
// CHECK: } // CHECK: }
// CHECK: gpu.terminator // CHECK: gpu.terminator
// CHECK: } // CHECK: }

View File

@ -4,7 +4,7 @@
// CHECK-LABEL: func @element_wise // CHECK-LABEL: func @element_wise
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) { %result: memref<2x2xf32>) {
"xla_lhlo.add"(%lhs, %rhs, %result) "lmhlo.add"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
@ -19,7 +19,7 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>, func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>,
%rhs: memref<?x?xf32>, %rhs: memref<?x?xf32>,
%result: memref<?x?xf32>) { %result: memref<?x?xf32>) {
"xla_lhlo.add"(%lhs, %rhs, %result) "lmhlo.add"(%lhs, %rhs, %result)
: (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return return
} }
@ -33,7 +33,7 @@ func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>,
// CHECK-LABEL: func @element_wise_scalar // CHECK-LABEL: func @element_wise_scalar
func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>, func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
%result: memref<f32>) { %result: memref<f32>) {
"xla_lhlo.add"(%lhs, %rhs, %result) "lmhlo.add"(%lhs, %rhs, %result)
: (memref<f32>, memref<f32>, memref<f32>) -> () : (memref<f32>, memref<f32>, memref<f32>) -> ()
return return
} }
@ -48,7 +48,7 @@ func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
// CHECK-LABEL: func @minf // CHECK-LABEL: func @minf
func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) { %result: memref<2x2xf32>) {
"xla_lhlo.minimum"(%lhs, %rhs, %result) "lmhlo.minimum"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
@ -63,7 +63,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @maxi // CHECK-LABEL: func @maxi
func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) { %result: memref<2x2xi32>) {
"xla_lhlo.maximum"(%lhs, %rhs, %result) "lmhlo.maximum"(%lhs, %rhs, %result)
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
return return
} }
@ -78,7 +78,7 @@ func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// CHECK-LABEL: func @and // CHECK-LABEL: func @and
func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) { %result: memref<2x2xi32>) {
"xla_lhlo.and"(%lhs, %rhs, %result) "lmhlo.and"(%lhs, %rhs, %result)
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
return return
} }
@ -91,7 +91,7 @@ func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// CHECK-LABEL: func @exp // CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.exponential"(%input, %result) "lmhlo.exponential"(%input, %result)
: (memref<2x2xf32>, memref<2x2xf32>) -> () : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
@ -104,7 +104,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @log // CHECK-LABEL: func @log
func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -116,7 +116,7 @@ func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @copy // CHECK-LABEL: func @copy
func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
"xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () "lmhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -128,7 +128,7 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
// CHECK-LABEL: func @float_cmp // CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xi1>) { %result: memref<2x2xi1>) {
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"} "lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> () : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> ()
return return
} }
@ -142,7 +142,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @int_cmp // CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi1>) { %result: memref<2x2xi1>) {
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} "lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> () : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> ()
return return
} }
@ -156,7 +156,7 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// CHECK-LABEL: func @select // CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.select"(%pred, %lhs, %rhs, %result) "lmhlo.select"(%pred, %lhs, %rhs, %result)
: (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () : (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
@ -170,7 +170,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota // CHECK-LABEL: func @iota
func @iota(%out: memref<7x10xf32>) { func @iota(%out: memref<7x10xf32>) {
"xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> () "lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
return return
} }
// CHECK: linalg.indexed_generic // CHECK: linalg.indexed_generic
@ -186,7 +186,7 @@ func @iota(%out: memref<7x10xf32>) {
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar // CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) { func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
"xla_lhlo.broadcast"(%operand, %result) { "lmhlo.broadcast"(%operand, %result) {
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
} : (memref<f32>, memref<4x2x1xf32>) -> () } : (memref<f32>, memref<4x2x1xf32>) -> ()
return return
@ -203,7 +203,7 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
// CHECK-LABEL: func @broadcast // CHECK-LABEL: func @broadcast
func @broadcast(%operand: memref<4x?x16xf32>, func @broadcast(%operand: memref<4x?x16xf32>,
%result: memref<4x2x1x4x?x16xf32>) { %result: memref<4x2x1x4x?x16xf32>) {
"xla_lhlo.broadcast"(%operand, %result) { "lmhlo.broadcast"(%operand, %result) {
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
} : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> () } : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> ()
return return
@ -220,7 +220,7 @@ func @broadcast(%operand: memref<4x?x16xf32>,
// CHECK-LABEL: func @dynamic_broadcast_in_dim // CHECK-LABEL: func @dynamic_broadcast_in_dim
func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>, func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
%result: memref<?x?x?x?x?xf32>) { %result: memref<?x?x?x?x?xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) { "lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64> broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>
} : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> () } : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
return return
@ -237,7 +237,7 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_no_expansion // CHECK-LABEL: func @static_broadcast_in_dim_no_expansion
func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>, func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
%result: memref<5x10xf32>) { %result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) { "lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[0]> : tensor<1xi64> broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<5xf32>, memref<5x10xf32>) -> () } : (memref<5xf32>, memref<5x10xf32>) -> ()
return return
@ -255,7 +255,7 @@ func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_expansion // CHECK-LABEL: func @static_broadcast_in_dim_expansion
func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>, func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
%result: memref<5x10x100xf32>) { %result: memref<5x10x100xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) { "lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[2, 0]> : tensor<2xi64> broadcast_dimensions = dense<[2, 0]> : tensor<2xi64>
} : (memref<1x5xf32>, memref<5x10x100xf32>) -> () } : (memref<1x5xf32>, memref<5x10x100xf32>) -> ()
return return
@ -274,7 +274,7 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_scalar // CHECK-LABEL: func @static_broadcast_in_dim_scalar
func @static_broadcast_in_dim_scalar(%operand: memref<f32>, func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
%result: memref<5x10xf32>) { %result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) { "lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[]> : tensor<0xi64> broadcast_dimensions = dense<[]> : tensor<0xi64>
} : (memref<f32>, memref<5x10xf32>) -> () } : (memref<f32>, memref<5x10xf32>) -> ()
return return
@ -291,7 +291,7 @@ func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one // CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one
func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
%result: memref<1x5xf32>) { %result: memref<1x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) { "lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[0]> : tensor<1xi64> broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<1xf32>, memref<1x5xf32>) -> () } : (memref<1xf32>, memref<1x5xf32>) -> ()
return return
@ -307,7 +307,7 @@ func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many // CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many
func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>, func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
%result: memref<5x5xf32>) { %result: memref<5x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) { "lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[1]> : tensor<1xi64> broadcast_dimensions = dense<[1]> : tensor<1xi64>
} : (memref<1xf32>, memref<5x5xf32>) -> () } : (memref<1xf32>, memref<5x5xf32>) -> ()
return return
@ -323,7 +323,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
// CHECK-LABEL: func @constant // CHECK-LABEL: func @constant
func @constant(%value: memref<i32>) { func @constant(%value: memref<i32>) {
"xla_lhlo.constant"(%value) { "lmhlo.constant"(%value) {
value = dense<10> : tensor<i32> value = dense<10> : tensor<i32>
} : (memref<i32>) -> () } : (memref<i32>) -> ()
return return
@ -335,7 +335,7 @@ func @constant(%value: memref<i32>) {
// CHECK-LABEL: func @absf // CHECK-LABEL: func @absf
func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -348,7 +348,7 @@ func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @absi // CHECK-LABEL: func @absi
func @absi(%input: memref<2x2xi32>, func @absi(%input: memref<2x2xi32>,
%result: memref<2x2xi32>) { %result: memref<2x2xi32>) {
"xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () "lmhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return return
} }
@ -364,7 +364,7 @@ func @absi(%input: memref<2x2xi32>,
// CHECK-LABEL: func @ceil // CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -376,7 +376,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_i32_to_f32 // CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> () "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -389,7 +389,7 @@ func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_i16_to_i32 // CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: memref<2x2xi16>, func @convert_i16_to_i32(%input: memref<2x2xi16>,
%result: memref<2x2xi32>) { %result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> () "lmhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -401,7 +401,7 @@ func @convert_i16_to_i32(%input: memref<2x2xi16>,
// CHECK-LABEL: func @convert_i32_to_i16 // CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) { func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> () "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -413,7 +413,7 @@ func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) {
// CHECK-LABEL: func @convert_f32_to_f64 // CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) { func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> () "lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -425,7 +425,7 @@ func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) {
// CHECK-LABEL: func @convert_f64_to_f32 // CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) { func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> () "lmhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -437,7 +437,7 @@ func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_i32_to_i32 // CHECK-LABEL: func @convert_i32_to_i32
func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -448,7 +448,7 @@ func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @convert_f32_to_f32 // CHECK-LABEL: func @convert_f32_to_f32
func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -459,7 +459,7 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_f32_to_i32 // CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result) "lmhlo.convert"(%input, %result)
: (memref<2x2xf32>, memref<2x2xi32>) -> () : (memref<2x2xf32>, memref<2x2xi32>) -> ()
return return
} }
@ -472,7 +472,7 @@ func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @cos // CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -485,7 +485,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sin // CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>, func @sin(%input: memref<2x2xf32>,
%result: memref<2x2xf32>) { %result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result) "lmhlo.sine"(%input, %result)
: (memref<2x2xf32>, memref<2x2xf32>) -> () : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
@ -498,7 +498,7 @@ func @sin(%input: memref<2x2xf32>,
// CHECK-LABEL: func @negf // CHECK-LABEL: func @negf
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -510,7 +510,7 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @negi // CHECK-LABEL: func @negi
func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
"xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () "lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -524,7 +524,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @rem // CHECK-LABEL: func @rem
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) { %result: memref<2x2xf32>) {
"xla_lhlo.remainder"(%lhs, %rhs, %result) "lmhlo.remainder"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
@ -537,7 +537,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @rsqrt // CHECK-LABEL: func @rsqrt
func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -549,7 +549,7 @@ func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sign // CHECK-LABEL: func @sign
func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -562,7 +562,7 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sqrt // CHECK-LABEL: func @sqrt
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -574,7 +574,7 @@ func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @tanh // CHECK-LABEL: func @tanh
func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -588,7 +588,7 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @complex(%real: memref<2x2xf32>, func @complex(%real: memref<2x2xf32>,
%imag: memref<2x2xf32>, %imag: memref<2x2xf32>,
%cplx: memref<2x2xcomplex<f32>>) { %cplx: memref<2x2xcomplex<f32>>) {
"xla_lhlo.complex"(%real, %imag, %cplx) "lmhlo.complex"(%real, %imag, %cplx)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex<f32>>) -> () : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex<f32>>) -> ()
return return
} }
@ -602,7 +602,7 @@ func @complex(%real: memref<2x2xf32>,
// CHECK-LABEL: func @real // CHECK-LABEL: func @real
func @real(%cplx: memref<2x2xcomplex<f32>>, func @real(%cplx: memref<2x2xcomplex<f32>>,
%real: memref<2x2xf32>) { %real: memref<2x2xf32>) {
"xla_lhlo.real"(%cplx, %real) "lmhlo.real"(%cplx, %real)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> () : (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return return
} }
@ -616,7 +616,7 @@ func @real(%cplx: memref<2x2xcomplex<f32>>,
// CHECK-LABEL: func @imag // CHECK-LABEL: func @imag
func @imag(%cplx: memref<2x2xcomplex<f32>>, func @imag(%cplx: memref<2x2xcomplex<f32>>,
%imag: memref<2x2xf32>) { %imag: memref<2x2xf32>) {
"xla_lhlo.imag"(%cplx, %imag) "lmhlo.imag"(%cplx, %imag)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> () : (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return return
} }
@ -629,7 +629,7 @@ func @imag(%cplx: memref<2x2xcomplex<f32>>,
// CHECK: func @slice(%[[IN:.*]]: memref<?x?xf32>, %[[OUT:.*]]: memref<?x?xf32>) // CHECK: func @slice(%[[IN:.*]]: memref<?x?xf32>, %[[OUT:.*]]: memref<?x?xf32>)
func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) { func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
"xla_lhlo.slice"(%operand, %result) { "lmhlo.slice"(%operand, %result) {
start_indices = dense<[0,1]> : tensor<2xi64>, start_indices = dense<[0,1]> : tensor<2xi64>,
limit_indices = dense<[2,3]> : tensor<2xi64>, limit_indices = dense<[2,3]> : tensor<2xi64>,
strides = dense<[1,1]> : tensor<2xi64> strides = dense<[1,1]> : tensor<2xi64>
@ -653,7 +653,7 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D // CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1) "lmhlo.reshape"(%arg0, %arg1)
: (memref<12x1x42xi32>, memref<12x42xi32>) -> () : (memref<12x1x42xi32>, memref<12x42xi32>) -> ()
return return
} }
@ -666,7 +666,7 @@ func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D // CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1) "lmhlo.reshape"(%arg0, %arg1)
: (memref<12x42x1x1xi32>, memref<12x42xi32>) -> () : (memref<12x42x1x1xi32>, memref<12x42xi32>) -> ()
return return
} }
@ -679,7 +679,7 @@ func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D // CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1) "lmhlo.reshape"(%arg0, %arg1)
: (memref<12x42xi32>, memref<12x1x42x1xi32>) -> () : (memref<12x42xi32>, memref<12x1x42x1xi32>) -> ()
return return
} }
@ -692,7 +692,7 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse // CHECK-LABEL: func @reverse
func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) { func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
"xla_lhlo.reverse"(%arg0, %arg1) { "lmhlo.reverse"(%arg0, %arg1) {
dimensions = dense<1> : tensor<1xi64> dimensions = dense<1> : tensor<1xi64>
} : (memref<2x3xf32>, memref<2x3xf32>) -> () } : (memref<2x3xf32>, memref<2x3xf32>) -> ()
return return
@ -710,15 +710,15 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64> // CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: strides = [2, 1]} // CHECK-SAME: strides = [2, 1]}
// With all atributes explicitly specified. // With all atributes explicitly specified.
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () "lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
// Dilation left unspecified, sets default dilation since linalg expects it. // Dilation left unspecified, sets default dilation since linalg expects it.
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}}) // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 1] // CHECK-SAME: dilations = [1, 1]
// Padding is not set if it's zero. // Padding is not set if it's zero.
// CHECK-NOT: padding // CHECK-NOT: padding
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () "lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> () "lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }

View File

@ -2,7 +2,7 @@
// CHECK-LABEL: func @static_memref_cast // CHECK-LABEL: func @static_memref_cast
func @static_memref_cast(%buf : memref<10x1x5xf32>) { func @static_memref_cast(%buf : memref<10x1x5xf32>) {
%0 = xla_lhlo.static_memref_cast %buf %0 = lmhlo.static_memref_cast %buf
: memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]> : memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]>
return return
} }
@ -38,7 +38,7 @@ func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
%size_Y = constant 50 : index %size_Y = constant 50 : index
%stride_X = constant 1 : index %stride_X = constant 1 : index
%stride_Y = constant 0 : index %stride_Y = constant 0 : index
%0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y] %0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]> : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return return
} }

View File

@ -3,11 +3,11 @@
func @reduce(%arg: memref<100x10x5xf32>, func @reduce(%arg: memref<100x10x5xf32>,
%init: memref<f32>, %init: memref<f32>,
%result: memref<100x5xf32>) { %result: memref<100x5xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( { "lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res) "lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> () : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>} } ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> () : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
return return
@ -35,7 +35,7 @@ func @reduce(%arg: memref<100x10x5xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32> // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: } // CHECK: }
@ -49,11 +49,11 @@ func @reduce(%arg: memref<100x10x5xf32>,
func @reduce_no_outer_loop(%arg: memref<100xf32>, func @reduce_no_outer_loop(%arg: memref<100xf32>,
%init: memref<f32>, %init: memref<f32>,
%result: memref<1xf32>) { %result: memref<1xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( { "lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res) "lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> () : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>} } ) {dimensions = dense<[0]> : tensor<1xi64>}
: (memref<100xf32>, memref<f32>, memref<1xf32>) -> () : (memref<100xf32>, memref<f32>, memref<1xf32>) -> ()
return return
@ -76,7 +76,7 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32> // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] // CHECK: scf.reduce.return [[ACC_RESULT]]
// CHECK: } // CHECK: }
@ -88,11 +88,11 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
func @dynamic_reduce(%arg: memref<?x?x?xf32>, func @dynamic_reduce(%arg: memref<?x?x?xf32>,
%init: memref<f32>, %init: memref<f32>,
%result: memref<?x?xf32>) { %result: memref<?x?xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( { "lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res) "lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> () : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>} } ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<?x?x?xf32>, memref<f32>, memref<?x?xf32>) -> () : (memref<?x?x?xf32>, memref<f32>, memref<?x?xf32>) -> ()
return return
@ -121,7 +121,7 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32> // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: } // CHECK: }
@ -135,11 +135,11 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
func @reduce_window(%arg: memref<112x112xf32>, func @reduce_window(%arg: memref<112x112xf32>,
%init: memref<f32>, %init: memref<f32>,
%result: memref<56x56xf32>) { %result: memref<56x56xf32>) {
"xla_lhlo.reduce_window"(%arg, %init, %result) ( { "lmhlo.reduce_window"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.maximum"(%lhs, %rhs, %res) "lmhlo.maximum"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> () : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
}) { }) {
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
window_dimensions = dense<[3, 3]> : tensor<2xi64>, window_dimensions = dense<[3, 3]> : tensor<2xi64>,
@ -189,7 +189,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32> // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: } // CHECK: }

View File

@ -4,7 +4,7 @@
// CHECK-LABEL: func @ceil // CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
@ -12,7 +12,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point values}} // expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () "lmhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return return
} }
@ -20,7 +20,7 @@ func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @cos // CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
@ -28,7 +28,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @cos // CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) { func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> () "lmhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return return
} }
@ -36,7 +36,7 @@ func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}} // expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () "lmhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return return
} }
@ -44,7 +44,7 @@ func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @sin // CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
@ -52,7 +52,7 @@ func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sin // CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) { func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> () "lmhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return return
} }
@ -60,7 +60,7 @@ func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}} // expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () "lmhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return return
} }
@ -68,7 +68,7 @@ func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @add_memrefs // CHECK-LABEL: func @add_memrefs
func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () "lmhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return return
} }
@ -76,7 +76,7 @@ func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1
// CHECK-LABEL: func @abs_memref // CHECK-LABEL: func @abs_memref
func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () "lmhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -84,7 +84,7 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @convert_memref // CHECK-LABEL: func @convert_memref
func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () { func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> () "lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
return return
} }
@ -92,7 +92,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () { func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
// expected-error@+1{{requires the same shape for all operands}} // expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> () "lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
return return
} }
@ -100,7 +100,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
// CHECK-LABEL: func @exp // CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return return
} }
@ -108,7 +108,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @exp // CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) { func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> () "lmhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return return
} }
@ -116,7 +116,7 @@ func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}} // expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () "lmhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return return
} }
@ -124,7 +124,7 @@ func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @log_memref // CHECK-LABEL: func @log_memref
func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () "lmhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -132,7 +132,7 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @log_memref // CHECK-LABEL: func @log_memref
func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () { func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () "lmhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return return
} }
@ -140,7 +140,7 @@ func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) ->
func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}} // expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () "lmhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return return
} }
@ -148,7 +148,7 @@ func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @neg_memref // CHECK-LABEL: func @neg_memref
func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () "lmhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -156,7 +156,7 @@ func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @rsqrt_memref // CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () "lmhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -164,7 +164,7 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @rsqrt_memref // CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () { func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () "lmhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return return
} }
@ -172,7 +172,7 @@ func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>)
func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}} // expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () "lmhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return return
} }
@ -180,7 +180,7 @@ func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @sqrt_memref // CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () "lmhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -188,7 +188,7 @@ func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @sqrt_memref // CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () { func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () "lmhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return return
} }
@ -196,7 +196,7 @@ func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -
func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}} // expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () "lmhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return return
} }
@ -204,7 +204,7 @@ func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @sign_memref // CHECK-LABEL: func @sign_memref
func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () "lmhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -212,7 +212,7 @@ func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @tanh_memref // CHECK-LABEL: func @tanh_memref
func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () "lmhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -220,7 +220,7 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @tanh_memref // CHECK-LABEL: func @tanh_memref
func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () { func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> () "lmhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return return
} }
@ -228,15 +228,15 @@ func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -
func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}} // expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () "lmhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return return
} }
// ----- // -----
func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () { func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}} // expected-error@+1{{'lmhlo.tanh' op requires all operands to have the same type}}
"xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> () "lmhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
return return
} }
@ -244,7 +244,7 @@ func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
// CHECK-LABEL: func @add_memref // CHECK-LABEL: func @add_memref
func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () "lmhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -252,7 +252,7 @@ func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @div_memref // CHECK-LABEL: func @div_memref
func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () "lmhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -260,7 +260,7 @@ func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @max_memref // CHECK-LABEL: func @max_memref
func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () "lmhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -268,7 +268,7 @@ func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @min_memref // CHECK-LABEL: func @min_memref
func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () "lmhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -276,7 +276,7 @@ func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @mul_memref // CHECK-LABEL: func @mul_memref
func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () "lmhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -284,7 +284,7 @@ func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @sub_memref // CHECK-LABEL: func @sub_memref
func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () "lmhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -292,7 +292,7 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @and_memref // CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return return
} }
@ -300,7 +300,7 @@ func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32
// CHECK-LABEL: func @and_memref // CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return return
} }
@ -308,7 +308,7 @@ func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>)
func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -316,7 +316,7 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @or_memref // CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return return
} }
@ -324,7 +324,7 @@ func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>
// CHECK-LABEL: func @or_memref // CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return return
} }
@ -332,7 +332,7 @@ func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -
func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -340,7 +340,7 @@ func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>
// CHECK-LABEL: func @xor_memref // CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return return
} }
@ -348,7 +348,7 @@ func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32
// CHECK-LABEL: func @xor_memref // CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return return
} }
@ -356,7 +356,7 @@ func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>)
func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return return
} }
@ -364,7 +364,7 @@ func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @broadcast_in_dim_memref // CHECK-LABEL: func @broadcast_in_dim_memref
func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () { func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> () "lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
return return
} }
@ -372,7 +372,7 @@ func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -
// CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref // CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref
func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi32>) -> () { func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> () "lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
return return
} }
@ -381,10 +381,10 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi
// CHECK-LABEL: func @reduce_memref // CHECK-LABEL: func @reduce_memref
func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () { func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {
"xla_lhlo.reduce"(%input, %init, %out) ( { "lmhlo.reduce"(%input, %init, %out) ( {
^bb0(%arg1: memref<f32>, %arg2: memref<f32>, %result: memref<f32>): ^bb0(%arg1: memref<f32>, %arg2: memref<f32>, %result: memref<f32>):
"xla_lhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> () "lmhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref<f32>, memref<1xf32>) -> () } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref<f32>, memref<1xf32>) -> ()
return return
} }
@ -393,14 +393,14 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf
// CHECK-LABEL: func @fusion_memref // CHECK-LABEL: func @fusion_memref
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () { func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.fusion"() ( { "lmhlo.fusion"() ( {
%0 = tensor_load %input1 : memref<10xf32> %0 = tensor_load %input1 : memref<10xf32>
%1 = tensor_load %input2 : memref<10xf32> %1 = tensor_load %input2 : memref<10xf32>
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> %2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%3 = tensor_load %input3 : memref<10xf32> %3 = tensor_load %input3 : memref<10xf32>
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> %4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
tensor_store %4, %out : memref<10xf32> tensor_store %4, %out : memref<10xf32>
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} ) : () -> () } ) : () -> ()
return return
} }
@ -409,18 +409,18 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
// CHECK-LABEL: func @case_memref // CHECK-LABEL: func @case_memref
func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () { func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () {
"xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { "lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
^bb0(%arg0: memref<f32>): ^bb0(%arg0: memref<f32>):
"xla_lhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> () "lmhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
}, { }, {
^bb0(%arg0: memref<f32>): ^bb0(%arg0: memref<f32>):
"xla_lhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> () "lmhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
}, { }, {
^bb0(%arg0: memref<f32>): ^bb0(%arg0: memref<f32>):
"xla_lhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> () "lmhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }
) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>} ) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>}
: (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> () : (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
@ -430,7 +430,7 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr
// ----- // -----
func @static_memref_cast(%in: memref<10x1xf32>) { func @static_memref_cast(%in: memref<10x1xf32>) {
%out = xla_lhlo.static_memref_cast %in %out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]> : memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]>
return return
} }
@ -440,7 +440,7 @@ func @static_memref_cast(%in: memref<10x1xf32>) {
func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) { func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
// expected-error @+1 {{operand must have static shape}} // expected-error @+1 {{operand must have static shape}}
%out = xla_lhlo.static_memref_cast %in %out = lmhlo.static_memref_cast %in
: memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]> : memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]>
return return
} }
@ -449,7 +449,7 @@ func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) { func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
// expected-error @+1 {{result must have static shape}} // expected-error @+1 {{result must have static shape}}
%out = xla_lhlo.static_memref_cast %in %out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]> : memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]>
return return
} }
@ -459,7 +459,7 @@ func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
func @dynamic_memref_cast(%in: memref<?xf32>) { func @dynamic_memref_cast(%in: memref<?xf32>) {
%size = constant 10 : index %size = constant 10 : index
%step = constant 1 : index %step = constant 1 : index
%out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] %out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]> : memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>
return return
} }
@ -471,7 +471,7 @@ func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
// expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}} // expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}}
%size = constant 10 : index %size = constant 10 : index
%step = constant 1 : index %step = constant 1 : index
%out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] %out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]> : memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return return
} }
@ -483,19 +483,19 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
// CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>, // CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>,
// CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32> // CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32>
// CHECK-NEXT: [[DYN_VEC:%.*]] = xla_lhlo.reshape_memref_cast [[UNRANKED]] // CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]]
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32> // CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
%dyn_vec = xla_lhlo.reshape_memref_cast %unranked(%shape1) %dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1)
: (memref<*xf32>, memref<1xi32>) -> memref<?xf32> : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// CHECK-NEXT: [[DYN_MAT:%.*]] = xla_lhlo.reshape_memref_cast [[DYN_VEC]] // CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]]
// CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32> // CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
%dyn_mat = xla_lhlo.reshape_memref_cast %dyn_vec(%shape2) %dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2)
: (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32> : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
// CHECK-NEXT: {{%.*}} = xla_lhlo.reshape_memref_cast [[DYN_MAT]] // CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]]
// CHECK-SAME: : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32> // CHECK-SAME: : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
%new_unranked = xla_lhlo.reshape_memref_cast %dyn_mat(%shape3) %new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3)
: (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32> : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
return return
} }
@ -505,7 +505,7 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
func @reshape_memref_cast_element_type_mismatch( func @reshape_memref_cast_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) { %buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}} // expected-error @+1 {{element types of source and destination memref types should be the same}}
xla_lhlo.reshape_memref_cast %buf(%shape) lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<1xi32>) -> memref<?xi32> : (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
} }
@ -514,7 +514,7 @@ func @reshape_memref_cast_element_type_mismatch(
func @reshape_memref_cast_dst_ranked_shape_unranked( func @reshape_memref_cast_dst_ranked_shape_unranked(
%buf: memref<*xf32>, %shape: memref<?xi32>) { %buf: memref<*xf32>, %shape: memref<?xi32>) {
// expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}} // expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}}
xla_lhlo.reshape_memref_cast %buf(%shape) lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<?xi32>) -> memref<?xf32> : (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
return return
} }
@ -524,7 +524,7 @@ func @reshape_memref_cast_dst_ranked_shape_unranked(
func @reshape_memref_cast_dst_shape_rank_mismatch( func @reshape_memref_cast_dst_shape_rank_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) { %buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{length of shape operand differs from the result's memref rank}} // expected-error @+1 {{length of shape operand differs from the result's memref rank}}
xla_lhlo.reshape_memref_cast %buf(%shape) lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32> : (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
return return
} }
@ -535,7 +535,7 @@ func @reshape_memref_cast_affine_map_is_not_identity(
%buf: memref<4x4xf32, offset: 0, strides: [3, 2]>, %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
%shape: memref<1xi32>) { %shape: memref<1xi32>) {
// expected-error @+1 {{operand memref type should have identity affine map}} // expected-error @+1 {{operand memref type should have identity affine map}}
xla_lhlo.reshape_memref_cast %buf(%shape) lmhlo.reshape_memref_cast %buf(%shape)
: (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>) : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
-> memref<8xf32> -> memref<8xf32>
return return
@ -545,7 +545,7 @@ func @reshape_memref_cast_affine_map_is_not_identity(
// CHECK-LABEL: func @atan2_memrefs // CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -553,7 +553,7 @@ func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref
// CHECK-LABEL: func @atan2_memrefs // CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () { func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> () "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return return
} }
@ -561,7 +561,7 @@ func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>
func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}} // expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return return
} }
@ -569,7 +569,7 @@ func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref
// CHECK-LABEL: func @bitcast_convert_memrefs // CHECK-LABEL: func @bitcast_convert_memrefs
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () { func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> () "lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
return return
} }
@ -577,7 +577,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) ->
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () { func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () {
// expected-error@+1{{requires the same shape for all operands}} // expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> () "lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
return return
} }
@ -585,7 +585,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) ->
// CHECK-LABEL: func @clz_memrefs // CHECK-LABEL: func @clz_memrefs
func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () "lmhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return return
} }
@ -593,7 +593,7 @@ func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @expm1_memrefs // CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () "lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -601,7 +601,7 @@ func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @expm1_memrefs // CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () { func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> () "lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return return
} }
@ -609,7 +609,7 @@ func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f3
// CHECK-LABEL: func @floor_memrefs // CHECK-LABEL: func @floor_memrefs
func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () "lmhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -617,7 +617,7 @@ func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point values}} // expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () "lmhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return return
} }
@ -625,7 +625,7 @@ func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @imag_memrefs // CHECK-LABEL: func @imag_memrefs
func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () { func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> () "lmhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
return return
} }
@ -633,7 +633,7 @@ func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> ()
func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error@+1{{must be memref of complex-type values}} // expected-error@+1{{must be memref of complex-type values}}
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () "lmhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -641,7 +641,7 @@ func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @real_memrefs // CHECK-LABEL: func @real_memrefs
func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () { func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> () "lmhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
return return
} }
@ -649,7 +649,7 @@ func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> ()
func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error@+1{{must be memref of complex-type values}} // expected-error@+1{{must be memref of complex-type values}}
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () "lmhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -657,7 +657,7 @@ func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @is_finite_memrefs // CHECK-LABEL: func @is_finite_memrefs
func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () { func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
"xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> () "lmhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
return return
} }
@ -665,7 +665,7 @@ func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
// CHECK-LABEL: func @log1p_memrefs // CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () "lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -673,7 +673,7 @@ func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @log1p_memrefs // CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () { func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> () "lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return return
} }
@ -681,7 +681,7 @@ func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f3
func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}} // expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () "lmhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return return
} }
@ -689,7 +689,7 @@ func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @not_memrefs // CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () "lmhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return return
} }
@ -697,7 +697,7 @@ func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @not_memrefs // CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () { func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> () "lmhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
return return
} }
@ -705,7 +705,7 @@ func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () "lmhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -713,7 +713,7 @@ func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @popcnt_memrefs // CHECK-LABEL: func @popcnt_memrefs
func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () "lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return return
} }
@ -721,7 +721,7 @@ func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () "lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -729,7 +729,7 @@ func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @reduce_precision_memrefs // CHECK-LABEL: func @reduce_precision_memrefs
func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> () "lmhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -737,7 +737,7 @@ func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) ->
// CHECK-LABEL: func @round_memrefs // CHECK-LABEL: func @round_memrefs
func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () "lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -745,7 +745,7 @@ func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point values}} // expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () "lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return return
} }
@ -753,7 +753,7 @@ func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @shift_left_memrefs // CHECK-LABEL: func @shift_left_memrefs
func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () "lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return return
} }
@ -761,7 +761,7 @@ func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: m
func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () "lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -769,7 +769,7 @@ func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: m
// CHECK-LABEL: func @shift_right_arithmetic_memrefs // CHECK-LABEL: func @shift_right_arithmetic_memrefs
func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () "lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return return
} }
@ -777,7 +777,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>,
func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () "lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -785,7 +785,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>,
// CHECK-LABEL: func @shift_right_logical_memrefs // CHECK-LABEL: func @shift_right_logical_memrefs
func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () "lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return return
} }
@ -793,7 +793,7 @@ func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %a
func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () "lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return return
} }
@ -801,14 +801,14 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
// CHECK-LABEL: func @all_reduce_memrefs // CHECK-LABEL: func @all_reduce_memrefs
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () { func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({ "lmhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = mhlo.maximum %lhs, %rhs : tensor<f32> %max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> () "mhlo.return"(%max) : (tensor<f32>) -> ()
}) })
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> () { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({ "lmhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = mhlo.maximum %lhs, %rhs : tensor<f32> %max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> () "mhlo.return"(%max) : (tensor<f32>) -> ()
@ -826,11 +826,11 @@ func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> ()
// CHECK-LABEL: func @collective_permute_memrefs // CHECK-LABEL: func @collective_permute_memrefs
func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () { func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
"xla_lhlo.collective_permute"(%arg0, %arg_out) { "lmhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
} : (memref<128x32xf32>, memref<128x32xf32>) -> () } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
"xla_lhlo.collective_permute"(%arg0, %arg_out) { "lmhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
channel_id = { handle = 5 : i64, type = 2 : i64 } channel_id = { handle = 5 : i64, type = 2 : i64 }
} : (memref<128x32xf32>, memref<128x32xf32>) -> () } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
@ -841,7 +841,7 @@ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128
// CHECK-LABEL: func @fft_memrefs // CHECK-LABEL: func @fft_memrefs
func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () { func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
"xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> () "lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
return return
} }
@ -852,7 +852,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>, %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
%grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>, %grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>,
%grad_offset: memref<8xf32>) -> () { %grad_offset: memref<8xf32>) -> () {
"xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} "lmhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
return return
@ -863,7 +863,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>,
// CHECK-LABEL: func @batch_norm_inference_memrefs // CHECK-LABEL: func @batch_norm_inference_memrefs
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () { %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
"xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} "lmhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> () : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
return return
} }
@ -874,7 +874,7 @@ func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>, %output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>,
%batch_var: memref<8xf32>) -> () { %batch_var: memref<8xf32>) -> () {
"xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} "lmhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
return return
} }
@ -883,8 +883,8 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
// CHECK-LABEL: func @cholesky_memrefs // CHECK-LABEL: func @cholesky_memrefs
func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () { func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () {
"xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () "lmhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
"xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () "lmhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
return return
} }
@ -892,7 +892,7 @@ func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291x
// CHECK-LABEL: func @infeed_memrefs // CHECK-LABEL: func @infeed_memrefs
func @infeed_memrefs(%arg_out: memref<3xf32>) -> () { func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
"xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> () "lmhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
return return
} }
@ -900,7 +900,7 @@ func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
// CHECK-LABEL: func @outfeed_memrefs // CHECK-LABEL: func @outfeed_memrefs
func @outfeed_memrefs(%arg0: memref<3xf32>) -> () { func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
"xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> () "lmhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
return return
} }
@ -908,7 +908,7 @@ func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
// CHECK-LABEL: func @replica_id_memrefs // CHECK-LABEL: func @replica_id_memrefs
func @replica_id_memrefs(%arg_out: memref<ui32>) -> () { func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
"xla_lhlo.replica_id"(%arg_out) : (memref<ui32>) -> () "lmhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
return return
} }
@ -916,7 +916,7 @@ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
// CHECK-LABEL: func @triangular_solve_memrefs // CHECK-LABEL: func @triangular_solve_memrefs
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () { func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
"xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} "lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> () : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
return return
} }
@ -925,9 +925,9 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %
// CHECK-LABEL: func @while_memrefs // CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () { func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
"xla_lhlo.while"(%arg0, %arg_out) ( "lmhlo.while"(%arg0, %arg_out) (
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () }, { ^bb0(%arg: memref<i64>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "xla_lhlo.terminator"() : () -> () } { ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "lmhlo.terminator"() : () -> () }
) : (memref<i64>, memref<i64>) -> () ) : (memref<i64>, memref<i64>) -> ()
return return
} }
@ -936,9 +936,9 @@ func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
// CHECK-LABEL: func @while_memrefs // CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>) -> () { func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>) -> () {
"xla_lhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) ( "lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) (
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () }, { ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "xla_lhlo.terminator"() : () -> () } { ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () }
) : (memref<i64>, memref<5xf32>, memref<i64>, memref<5xf32>) -> () ) : (memref<i64>, memref<5xf32>, memref<i64>, memref<5xf32>) -> ()
return return
} }
@ -947,7 +947,7 @@ func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<
// CHECK-LABEL: func @bitcast_memrefs // CHECK-LABEL: func @bitcast_memrefs
func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () { func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
"xla_lhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> () "lmhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
return return
} }
@ -956,7 +956,7 @@ func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
// CHECK-LABEL: func @scatter_memrefs // CHECK-LABEL: func @scatter_memrefs
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>, func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () { %updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({ "lmhlo.scatter" (%input, %indices, %updates, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = mhlo.add %lhs, %rhs : tensor<f32> %add = mhlo.add %lhs, %rhs : tensor<f32>
"mhlo.return"(%add) : (tensor<f32>) -> () "mhlo.return"(%add) : (tensor<f32>) -> ()
@ -977,7 +977,7 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
// CHECK-LABEL: func @map_memrefs // CHECK-LABEL: func @map_memrefs
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () { func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ "lmhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>): ^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = mhlo.add %a, %b : tensor<f32> %c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> () "mhlo.return"(%c) : (tensor<f32>) -> ()
@ -989,7 +989,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () { func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () {
// expected-error@+1{{requires the same shape for all operands}} // expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ "lmhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>): ^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = mhlo.add %a, %b : tensor<f32> %c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> () "mhlo.return"(%c) : (tensor<f32>) -> ()
@ -1001,7 +1001,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
// CHECK-LABEL: func @rng_get_and_update_state_memrefs // CHECK-LABEL: func @rng_get_and_update_state_memrefs
func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () { func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
"xla_lhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> () "lmhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
return return
} }
@ -1010,7 +1010,7 @@ func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
// CHECK-LABEL: func @sort_memrefs // CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>): ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()
@ -1023,7 +1023,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
// CHECK-LABEL: func @sort_memrefs // CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>): ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()
@ -1036,7 +1036,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
// CHECK-LABEL: func @sort_memrefs // CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>): ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()