Rename xla_lhlo dialect into lmhlo
Following on the plan of isolating the compiler/mlir/hlo directory. Another xla_lhlo dialect will be created under compiler/mlir/xla/ later. PiperOrigin-RevId: 320210326
This commit is contained in:
parent
b076e018a8
commit
7c4a5d62b5
|
@ -35,18 +35,18 @@ class OpBuilder;
|
|||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc"
|
||||
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
|
||||
class XlaLhloDialect : public Dialect {
|
||||
class LmhloDialect : public Dialect {
|
||||
public:
|
||||
explicit XlaLhloDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "xla_lhlo"; }
|
||||
explicit LmhloDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "lmhlo"; }
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
||||
|
|
|
@ -38,8 +38,8 @@ include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInte
|
|||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
|
||||
def LHLO_Dialect : Dialect {
|
||||
let name = "xla_lhlo";
|
||||
let cppNamespace = "xla_lhlo";
|
||||
let name = "lmhlo";
|
||||
let cppNamespace = "lmhlo";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -253,7 +253,7 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
|
|||
|
||||
// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example,
|
||||
// A tuple-like pattern match syntax could work:
|
||||
// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
|
||||
// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
|
||||
// ...
|
||||
// }, {
|
||||
// ...
|
||||
|
@ -337,7 +337,7 @@ def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
|
|||
Example:
|
||||
```mlir
|
||||
%buf_transformed =
|
||||
xla_lhlo.static_memref_cast %buf
|
||||
lmhlo.static_memref_cast %buf
|
||||
: memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
|
||||
|
||||
// The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
|
||||
|
@ -379,7 +379,7 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
|
|||
Example:
|
||||
```mlir
|
||||
%buf_transformed =
|
||||
xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
|
||||
lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
|
||||
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
|
||||
// The result of the op is a type-erased memref with `[%size_X, %size_Y]`
|
||||
// shape and `[%step_X, %step_Y]` strides. The offset will be inherited
|
||||
|
|
|
@ -34,7 +34,7 @@ using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
|
|||
#define MAP_HLO_TO_LHLO(OpName) \
|
||||
template <> \
|
||||
struct HloToLhloOpImpl<mhlo::OpName> { \
|
||||
using Type = xla_lhlo::OpName; \
|
||||
using Type = lmhlo::OpName; \
|
||||
}
|
||||
|
||||
MAP_HLO_TO_LHLO(AbsOp);
|
||||
|
|
|
@ -24,7 +24,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace impl {
|
||||
|
||||
// A struct to map LhloBinaryOpTy type to the corresponding floating-point and
|
||||
|
@ -33,32 +33,32 @@ template <typename LhloBinaryOpTy>
|
|||
struct LhloToScalarOp;
|
||||
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::AddOp> {
|
||||
struct LhloToScalarOp<lmhlo::AddOp> {
|
||||
using FOp = ::mlir::AddFOp;
|
||||
using IOp = ::mlir::AddIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::CompareOp> {
|
||||
struct LhloToScalarOp<lmhlo::CompareOp> {
|
||||
using FOp = ::mlir::CmpFOp;
|
||||
using IOp = ::mlir::CmpIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::DivOp> {
|
||||
struct LhloToScalarOp<lmhlo::DivOp> {
|
||||
using FOp = ::mlir::DivFOp;
|
||||
using IOp = ::mlir::SignedDivIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::MulOp> {
|
||||
struct LhloToScalarOp<lmhlo::MulOp> {
|
||||
using FOp = ::mlir::MulFOp;
|
||||
using IOp = ::mlir::MulIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::RemOp> {
|
||||
struct LhloToScalarOp<lmhlo::RemOp> {
|
||||
using FOp = ::mlir::RemFOp;
|
||||
using IOp = ::mlir::SignedRemIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::SubOp> {
|
||||
struct LhloToScalarOp<lmhlo::SubOp> {
|
||||
using FOp = ::mlir::SubFOp;
|
||||
using IOp = ::mlir::SubIOp;
|
||||
};
|
||||
|
@ -116,8 +116,9 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
|
|||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
if (element_type.isa<FloatType>()) {
|
||||
|
@ -125,7 +126,7 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
|
|||
loc, result_types, args, b);
|
||||
}
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
// xla_lhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
|
||||
// lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
|
||||
Value lhs = args[0];
|
||||
auto integer_type = element_type.dyn_cast<IntegerType>();
|
||||
|
||||
|
@ -133,15 +134,16 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
|
|||
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||
auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
|
||||
lhs, zero_intval);
|
||||
auto neg_val = b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
|
||||
auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
|
||||
return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
||||
loc, result_types, args, b);
|
||||
|
@ -205,30 +207,33 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc,
|
|||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return args.front();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args,
|
||||
|
@ -236,21 +241,23 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
|
|||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RealOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ImagOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type sourceType = args.front().getType();
|
||||
|
@ -288,8 +295,9 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
|||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
// Dot Op converter from lhlo to affine only accepts float and integer types.
|
||||
const auto& lhs = args[0];
|
||||
|
@ -312,16 +320,18 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
|
|||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SinOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
|
||||
loc, result_types, args, b);
|
||||
|
@ -361,38 +371,40 @@ struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
|
|||
};
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return XlaCompareSelectOpToStdScalarOp<
|
||||
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
|
||||
result_types, args,
|
||||
b);
|
||||
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types,
|
||||
args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return XlaCompareSelectOpToStdScalarOp<
|
||||
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
|
||||
result_types, args,
|
||||
b);
|
||||
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types,
|
||||
args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
if (element_type.isa<FloatType>()) {
|
||||
|
@ -400,27 +412,28 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
|
|||
loc, result_types, args, b);
|
||||
}
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
// xla_lhlo.neg(x, result) -> result = sub(0, x)
|
||||
// lmhlo.neg(x, result) -> result = sub(0, x)
|
||||
Value lhs = args[0];
|
||||
auto integer_type = element_type.dyn_cast<IntegerType>();
|
||||
|
||||
auto zero_intval =
|
||||
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||
return b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
|
||||
return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RsqrtOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
|
||||
|
@ -428,8 +441,9 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
|
|||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
if (element_type.isa<FloatType>()) {
|
||||
|
@ -442,16 +456,18 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
|
|||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SqrtOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
|
||||
loc, result_types, args, b);
|
||||
|
@ -460,10 +476,10 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
|
|||
} // namespace impl
|
||||
|
||||
struct XlaOpToStdScalarOp {
|
||||
// Implementation for LHLO ops except xla_lhlo::CompareOp.
|
||||
// Implementation for LHLO ops except lmhlo::CompareOp.
|
||||
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
|
||||
typename = std::enable_if_t<
|
||||
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
|
||||
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
|
||||
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
|
||||
std::false_type>::value>>
|
||||
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
|
||||
|
@ -475,7 +491,7 @@ struct XlaOpToStdScalarOp {
|
|||
// Implementation for HLO ops except mhlo::CompareOp.
|
||||
template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
|
||||
typename = std::enable_if_t<
|
||||
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
|
||||
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
|
||||
!std::is_same<LhloOpTy, std::false_type>::value>>
|
||||
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
|
||||
|
@ -483,13 +499,13 @@ struct XlaOpToStdScalarOp {
|
|||
args, b);
|
||||
}
|
||||
|
||||
// Implementation for xla_lhlo::CompareOp.
|
||||
// Implementation for lmhlo::CompareOp.
|
||||
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
|
||||
LhloOpTy, xla_lhlo::CompareOp>::value>>
|
||||
static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
|
||||
LhloOpTy, lmhlo::CompareOp>::value>>
|
||||
static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
auto comparison_direction = op.comparison_direction();
|
||||
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
||||
op.getLoc(), comparison_direction, result_types, args, b);
|
||||
}
|
||||
|
||||
|
@ -500,12 +516,12 @@ struct XlaOpToStdScalarOp {
|
|||
static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
auto comparison_direction = op.comparison_direction();
|
||||
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
||||
op.getLoc(), comparison_direction, result_types, args, b);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
|
||||
|
|
|
@ -60,7 +60,7 @@ std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
|
|||
|
||||
} // namespace mhlo
|
||||
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
|
||||
// Lowers from LHLO dialect to Affine dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass();
|
||||
|
@ -92,7 +92,7 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();
|
|||
// Lowers from LHLO dialect to parallel loops.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
|
||||
namespace xla {
|
||||
|
||||
|
|
|
@ -75,14 +75,14 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
|
|||
|
||||
} // namespace mhlo
|
||||
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
|
||||
/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
|
||||
void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
|
||||
LLVMTypeConverter *converter,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
|
||||
namespace xla_chlo {
|
||||
|
||||
|
|
|
@ -21,4 +21,4 @@ limitations under the License.
|
|||
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
|
||||
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
|
||||
xla_chlo_ops;
|
||||
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;
|
||||
static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops;
|
||||
|
|
|
@ -46,9 +46,9 @@ limitations under the License.
|
|||
|
||||
namespace mlir {
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc"
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
|
||||
XlaLhloDialect::XlaLhloDialect(MLIRContext *context)
|
||||
LmhloDialect::LmhloDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
|
@ -138,5 +138,5 @@ void FusionOp::build(OpBuilder &builder, OperationState &result,
|
|||
FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
|
||||
}
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -44,7 +44,7 @@ template <typename T>
|
|||
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
|
||||
using StdReturnOpConverter =
|
||||
detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
|
||||
xla_lhlo::CopyOp, true>;
|
||||
lmhlo::CopyOp, true>;
|
||||
|
||||
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
||||
Value shape_operand,
|
||||
|
@ -149,7 +149,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
|||
|
||||
Value transformed_operand =
|
||||
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
|
||||
rewriter.create<xla_lhlo::BroadcastInDimOp>(
|
||||
rewriter.create<lmhlo::BroadcastInDimOp>(
|
||||
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
|
||||
|
||||
rewriter.replaceOp(op, {resultBuffer});
|
||||
|
@ -161,7 +161,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
|||
// Inserts dynamic memref to change the layout of the memref to put 0-stride
|
||||
// and size of the target dimension if size-1 dimension expansion is
|
||||
// necessary.
|
||||
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
|
||||
lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
|
||||
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
|
||||
auto loc = op.getLoc();
|
||||
auto operand_type = operand.getType().cast<MemRefType>();
|
||||
|
@ -214,7 +214,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
|||
makeStridedLinearLayoutMap(dynamic_layout,
|
||||
/*offset=*/0, b->getContext()));
|
||||
|
||||
auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>(
|
||||
auto transformed_operand = b->create<lmhlo::DynamicMemRefCastOp>(
|
||||
loc, type_erased_memref_type, operand, sizes, strides);
|
||||
return transformed_operand;
|
||||
}
|
||||
|
@ -239,7 +239,7 @@ struct HloToLhloDynamicReshapeConverter
|
|||
return failure();
|
||||
}
|
||||
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<xla_lhlo::ReshapeMemRefCastOp>(
|
||||
rewriter.replaceOpWithNewOp<lmhlo::ReshapeMemRefCastOp>(
|
||||
op, result_type, adaptor.operand(), adaptor.output_shape());
|
||||
return success();
|
||||
}
|
||||
|
@ -266,8 +266,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
|||
buffer_args.push_back(
|
||||
InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
|
||||
}
|
||||
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
|
||||
loc, llvm::None, buffer_args, op.getAttrs());
|
||||
auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
|
||||
op.getAttrs());
|
||||
|
||||
// Copy over the operations inside the region.
|
||||
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
|
||||
|
@ -292,7 +292,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
|||
}
|
||||
// Insert terminator at the end.
|
||||
rewriter.setInsertionPointToEnd(&entry_block);
|
||||
rewriter.create<xla_lhlo::TerminatorOp>(loc);
|
||||
rewriter.create<lmhlo::TerminatorOp>(loc);
|
||||
|
||||
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||
|
||||
|
@ -321,8 +321,8 @@ class HloToLhloTensorStoreOpConverter
|
|||
LogicalResult matchAndRewrite(
|
||||
mlir::TensorStoreOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
|
||||
op, llvm::None, operands.front(), operands.back());
|
||||
rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
|
||||
operands.back());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -336,7 +336,7 @@ class HloToLhloTensorStoreOpConverter
|
|||
// %arg1: memref<2x2xf32>,
|
||||
// %arg2: memref<2x2xf32>,
|
||||
// %arg3: memref<2x2xf32>) {
|
||||
// "xla_lhlo.fusion"() ({
|
||||
// "lmhlo.fusion"() ({
|
||||
// %0 = tensor_load %arg1 : memref<2x2xf32>
|
||||
// %1 = tensor_load %arg2 : memref<2x2xf32>
|
||||
// %2 = "mhlo.add"(%0, %1) :
|
||||
|
@ -345,7 +345,7 @@ class HloToLhloTensorStoreOpConverter
|
|||
// %4 = "mhlo.multiply"(%2, %3) :
|
||||
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// tensor_store %4, %arg3 : memref<2x2xf32>
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// "lmhlo.terminator"() : () -> ()
|
||||
// }) : () -> ()
|
||||
// return
|
||||
// }
|
||||
|
@ -355,13 +355,13 @@ class HloToLhloTensorStoreOpConverter
|
|||
// %arg1: memref<2x2xf32>,
|
||||
// %arg2: memref<2x2xf32>,
|
||||
// %arg3: memref<2x2xf32>) {
|
||||
// "xla_lhlo.fusion"() ( {
|
||||
// "lmhlo.fusion"() ( {
|
||||
// %0 = alloc() : memref<2x2xf32>
|
||||
// "xla_lhlo.add"(%arg1, %arg2, %0) :
|
||||
// "lmhlo.add"(%arg1, %arg2, %0) :
|
||||
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// "xla_lhlo.multiply"(%0, %arg0, %arg3) :
|
||||
// "lmhlo.multiply"(%0, %arg0, %arg3) :
|
||||
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// "lmhlo.terminator"() : () -> ()
|
||||
// }) : () -> ()
|
||||
// return
|
||||
// }
|
||||
|
@ -382,13 +382,13 @@ class HloToLhloTensorStoreOpConverter
|
|||
// %arg2: memref<4xf32>) {
|
||||
// %0 = alloc() : memref<4xf32>
|
||||
|
||||
// "xla_lhlo.maximum"(%arg0, %arg1, %0) :
|
||||
// "lmhlo.maximum"(%arg0, %arg1, %0) :
|
||||
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
|
||||
// %1 = alloc() : memref<4xf32>
|
||||
// "xla_lhlo.add"(%arg0, %0, %1) :
|
||||
// "lmhlo.add"(%arg0, %0, %1) :
|
||||
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
|
||||
// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// "lmhlo.terminator"() : () -> ()
|
||||
// }
|
||||
|
||||
struct HloLegalizeToLhlo
|
||||
|
@ -406,7 +406,7 @@ struct HloLegalizeToLhlo
|
|||
OwningRewritePatternList patterns;
|
||||
auto& context = getContext();
|
||||
ConversionTarget target(context);
|
||||
target.addLegalDialect<xla_lhlo::XlaLhloDialect>();
|
||||
target.addLegalDialect<lmhlo::LmhloDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalOp<ModuleOp>();
|
||||
target.addIllegalOp<mlir::TensorLoadOp>();
|
||||
|
@ -441,12 +441,12 @@ struct HloLegalizeToLhlo
|
|||
&converter, &patterns);
|
||||
if (results_escape_function) {
|
||||
populateWithBufferAssignmentOpConversionPatterns<
|
||||
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
|
||||
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp,
|
||||
/*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
|
||||
&converter, &patterns);
|
||||
} else {
|
||||
populateWithBufferAssignmentOpConversionPatterns<
|
||||
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
|
||||
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp,
|
||||
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
|
||||
&converter, &patterns);
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
// Removes LHLO copy operations that copy from allocated buffers to block
|
||||
|
@ -34,7 +34,7 @@ struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
|
|||
void runOnOperation() override {
|
||||
llvm::SmallVector<mlir::Operation*, 2> eraseList;
|
||||
auto operation = getOperation();
|
||||
operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) {
|
||||
operation->walk([&](mlir::lmhlo::CopyOp copyOp) {
|
||||
// If this region contains more than one block, then ignore this copy
|
||||
// operation.
|
||||
if (copyOp.getParentRegion()->getBlocks().size() > 1) {
|
||||
|
@ -101,5 +101,5 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass() {
|
|||
static PassRegistration<LhloCopyRemoval> copy_removal_pass(
|
||||
"lhlo-copy-removal", "Removes redundant LHLO copy operations");
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -27,7 +27,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
using linalg::LinalgOp;
|
||||
|
@ -147,5 +147,5 @@ static PassRegistration<LhloFuseLinalg> legalize_pass(
|
|||
"lhlo-fuse-linalg",
|
||||
"Greedily fuse linalg ops obtained after LHLO lowering.");
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -28,7 +28,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
|
||||
|
@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
|
|||
auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
|
||||
auto result =
|
||||
rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
|
||||
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<DotOp>(
|
||||
Value op_result = lmhlo::XlaOpToStdScalarOp::map<DotOp>(
|
||||
op, element_type, {l, r, result}, &builder);
|
||||
map_status = success(op_result != nullptr);
|
||||
if (failed(map_status)) return;
|
||||
|
@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
|
|||
ValueRange induction_vars) {
|
||||
auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
|
||||
auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
|
||||
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
|
||||
Value op_result = lmhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
|
||||
op, element_type, {l, r}, &builder);
|
||||
map_status = success(op_result != nullptr);
|
||||
if (failed(map_status)) return;
|
||||
|
@ -127,13 +127,13 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
|
|||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
BinaryOpConverter<xla_lhlo::AddOp>,
|
||||
BinaryOpConverter<xla_lhlo::AndOp>,
|
||||
BinaryOpConverter<xla_lhlo::DivOp>,
|
||||
BinaryOpConverter<xla_lhlo::MaxOp>,
|
||||
BinaryOpConverter<xla_lhlo::MinOp>,
|
||||
BinaryOpConverter<xla_lhlo::MulOp>,
|
||||
BinaryOpConverter<xla_lhlo::SubOp>,
|
||||
BinaryOpConverter<lmhlo::AddOp>,
|
||||
BinaryOpConverter<lmhlo::AndOp>,
|
||||
BinaryOpConverter<lmhlo::DivOp>,
|
||||
BinaryOpConverter<lmhlo::MaxOp>,
|
||||
BinaryOpConverter<lmhlo::MinOp>,
|
||||
BinaryOpConverter<lmhlo::MulOp>,
|
||||
BinaryOpConverter<lmhlo::SubOp>,
|
||||
DotOpConverter>(context);
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -157,5 +157,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() {
|
|||
static PassRegistration<LhloLegalizeToAffine> legalize_pass(
|
||||
"lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect");
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -38,7 +38,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
// A simple translation of LHLO reduce operations to a corresponding gpu
|
||||
|
@ -173,7 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
|
|||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>();
|
||||
gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
|
||||
target.addIllegalOp<ReduceOp>();
|
||||
auto func = getFunction();
|
||||
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
|
||||
|
@ -192,5 +192,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() {
|
|||
static PassRegistration<LhloLegalizeToGpu> legalize_pass(
|
||||
"lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect");
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
struct StaticMemRefCastOpConverter
|
||||
|
@ -132,5 +132,5 @@ void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
|
|||
*converter, options);
|
||||
}
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
class TestLhloToLLVMPass
|
||||
|
@ -42,7 +42,7 @@ class TestLhloToLLVMPass
|
|||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addIllegalDialect<XlaLhloDialect>();
|
||||
target.addIllegalDialect<LmhloDialect>();
|
||||
|
||||
if (failed(applyFullConversion(m, target, patterns))) {
|
||||
signalPassFailure();
|
||||
|
@ -55,5 +55,5 @@ class TestLhloToLLVMPass
|
|||
static PassRegistration<TestLhloToLLVMPass> legalize_lhlo_pass(
|
||||
"test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM.");
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
// Clones and adapts the code in `lhlo_block` that works on buffers and has a
|
||||
|
@ -154,14 +154,14 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
|
|||
return b->create<scf::ParallelOp>(loc, lower, upper, step);
|
||||
}
|
||||
|
||||
// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
|
||||
// Converts `lmhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
|
||||
// The outper `ParallelOp` refers to the parallel loops if there are
|
||||
// any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp`
|
||||
// contains the reduction operator.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( {
|
||||
// "lmhlo.reduce"(%buffer, %init_buf, %result) ( {
|
||||
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
// <LHLO ops>
|
||||
// } ) {dimensions = dense<[1]> : tensor<1xi64>}
|
||||
|
@ -187,12 +187,12 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
|
|||
// } : f32
|
||||
// scf.yield
|
||||
// }
|
||||
class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||
class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::ReduceOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
|
||||
lmhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
// TODO(b/137624192) Implement variadic reduce.
|
||||
if (xla_reduce_op.out().size() != 1) return failure();
|
||||
|
@ -226,7 +226,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
|||
// scf.yield
|
||||
// }
|
||||
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
|
||||
xla_lhlo::ReduceOp xla_reduce_op,
|
||||
lmhlo::ReduceOp xla_reduce_op,
|
||||
ConversionPatternRewriter* rewriter) const {
|
||||
auto loc = xla_reduce_op.getLoc();
|
||||
DenseSet<int> reducing_dims;
|
||||
|
@ -314,7 +314,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
|||
// accumulator = reduction_operator(output[O], value)
|
||||
// output[O] = accumulator
|
||||
//
|
||||
// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a
|
||||
// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a
|
||||
// scf::ReduceOp.
|
||||
// The outper `ParallelOp` refers to the parallel loops that traverese output
|
||||
// buffer. The inner `ParalleOp` refers to the reduction loops that traverse
|
||||
|
@ -325,11 +325,11 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
|||
// func @reduce_window(%arg: memref<112x112xf32>,
|
||||
// %init: memref<f32>,
|
||||
// %result: memref<56x56xf32>) {
|
||||
// "xla_lhlo.reduce_window"(%arg, %init, %result) ( {
|
||||
// "lmhlo.reduce_window"(%arg, %init, %result) ( {
|
||||
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
// "xla_lhlo.maximum"(%lhs, %rhs, %res)
|
||||
// "lmhlo.maximum"(%lhs, %rhs, %res)
|
||||
// : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// "lmhlo.terminator"() : () -> ()
|
||||
// }) {
|
||||
// padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||
// window_dimensions = dense<[3, 3]> : tensor<2xi64>,
|
||||
|
@ -359,12 +359,12 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
|||
// return
|
||||
// }
|
||||
class ReduceWindowOpConverter
|
||||
: public OpConversionPattern<xla_lhlo::ReduceWindowOp> {
|
||||
: public OpConversionPattern<lmhlo::ReduceWindowOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::ReduceWindowOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
|
||||
lmhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
scf::ParallelOp output_loop, window_loop;
|
||||
std::tie(output_loop, window_loop) =
|
||||
|
@ -383,7 +383,7 @@ class ReduceWindowOpConverter
|
|||
private:
|
||||
std::pair<scf::ParallelOp, scf::ParallelOp>
|
||||
CreateParallelLoopsToTraverseOutputAndWindow(
|
||||
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
|
||||
lmhlo::ReduceWindowOp xla_reduce_window_op,
|
||||
ConversionPatternRewriter* rewriter) const {
|
||||
auto loc = xla_reduce_window_op.getLoc();
|
||||
Value init_value =
|
||||
|
@ -415,9 +415,8 @@ class ReduceWindowOpConverter
|
|||
}
|
||||
|
||||
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
|
||||
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
|
||||
scf::ParallelOp output_loop, scf::ParallelOp window_loop,
|
||||
ConversionPatternRewriter* rewriter) const {
|
||||
lmhlo::ReduceWindowOp xla_reduce_window_op, scf::ParallelOp output_loop,
|
||||
scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
|
||||
rewriter->setInsertionPointToStart(window_loop.getBody());
|
||||
auto loc = xla_reduce_window_op.getLoc();
|
||||
|
||||
|
@ -481,12 +480,12 @@ class ReduceWindowOpConverter
|
|||
// initialized_flag = true
|
||||
// output(selected_index) = scatter(output(selected_index), source(S))
|
||||
class SelectAndScatterOpConverter
|
||||
: public OpConversionPattern<xla_lhlo::SelectAndScatterOp> {
|
||||
: public OpConversionPattern<lmhlo::SelectAndScatterOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::SelectAndScatterOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::SelectAndScatterOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
|
||||
lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = s_and_s_op.getLoc();
|
||||
InitializeOutput(s_and_s_op, &rewriter);
|
||||
|
@ -515,7 +514,7 @@ class SelectAndScatterOpConverter
|
|||
}
|
||||
|
||||
private:
|
||||
void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op,
|
||||
void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
|
||||
OpBuilder* b) const {
|
||||
auto loc = s_and_s_op.getLoc();
|
||||
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
|
||||
|
@ -533,7 +532,7 @@ class SelectAndScatterOpConverter
|
|||
SmallVector<Value, 2> window_ivs;
|
||||
scf::ForOp inner_loop;
|
||||
};
|
||||
WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op,
|
||||
WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op,
|
||||
scf::ParallelOp loop_over_src,
|
||||
OpBuilder* b) const {
|
||||
auto loc = s_and_s_op.getLoc();
|
||||
|
@ -598,7 +597,7 @@ class SelectAndScatterOpConverter
|
|||
SmallVector<Value, 4> ivs_val_flag_;
|
||||
};
|
||||
|
||||
SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op,
|
||||
SmallVector<Value, 2> SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op,
|
||||
scf::ParallelOp loop_over_src,
|
||||
OpBuilder* b) const {
|
||||
auto loc = s_and_s_op.getLoc();
|
||||
|
@ -636,9 +635,10 @@ class SelectAndScatterOpConverter
|
|||
return window_loops.selected_ivs;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> SelectOrInitialize(
|
||||
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> operand_ivs,
|
||||
IterArgs* ivs_val_flag, OpBuilder* b) const {
|
||||
SmallVector<Value, 4> SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op,
|
||||
ArrayRef<Value> operand_ivs,
|
||||
IterArgs* ivs_val_flag,
|
||||
OpBuilder* b) const {
|
||||
auto loc = s_and_s_op.getLoc();
|
||||
Value true_i1 = b->create<mlir::ConstantOp>(
|
||||
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
|
||||
|
@ -707,9 +707,9 @@ struct LhloLegalizeToParallelLoops
|
|||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
scf::SCFDialect, XlaLhloDialect>();
|
||||
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
|
||||
xla_lhlo::SelectAndScatterOp>();
|
||||
scf::SCFDialect, LmhloDialect>();
|
||||
target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
|
||||
lmhlo::SelectAndScatterOp>();
|
||||
|
||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||
signalPassFailure();
|
||||
|
@ -727,5 +727,5 @@ static PassRegistration<LhloLegalizeToParallelLoops> legalize_lhlo_pass(
|
|||
"lhlo-legalize-to-parallel-loops",
|
||||
"Legalize from LHLO dialect to parallel loops.");
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -131,9 +131,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
|||
loc, opResultTypes, args, args_count, results_count, indexing_maps,
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace.
|
||||
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
|
||||
// That method needs to be moved out of there.
|
||||
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
|
||||
Value opResult = lmhlo::XlaOpToStdScalarOp::map<OpTy>(
|
||||
op, bodyResultTypes,
|
||||
llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
|
||||
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
|
||||
|
@ -162,8 +162,8 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
|||
// Create two loads from the input.
|
||||
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
||||
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
||||
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
|
||||
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
|
||||
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
|
||||
Value opResult = lmhlo::XlaOpToStdScalarOp::map<LhloOp>(
|
||||
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
||||
&rewriter);
|
||||
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
|
||||
|
@ -173,21 +173,21 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
|||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// xla_lhlo.convolution conversion pattern.
|
||||
// lmhlo.convolution conversion pattern.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Converts xla_lhlo.convolution operation to a linalg.conv op.
|
||||
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
|
||||
/// Converts lmhlo.convolution operation to a linalg.conv op.
|
||||
struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;
|
||||
|
||||
// This code has been adapted from IREE's
|
||||
// (https://github.com/google/iree/) mhlo -> linalg conversion.
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ConvOp op, ArrayRef<Value> args,
|
||||
lmhlo::ConvOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
// Check validity of dimension information.
|
||||
if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
|
||||
if (const lmhlo::ConvDimensionNumbers& dimensionNumbers =
|
||||
op.dimension_numbers()) {
|
||||
const int inputSpatialRank =
|
||||
llvm::size(dimensionNumbers.input_spatial_dimensions());
|
||||
|
@ -388,14 +388,14 @@ class HloBroadcastInDimConverter
|
|||
};
|
||||
|
||||
class LhloBroadcastInDimConverter
|
||||
: public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
|
||||
: public OpConversionPattern<lmhlo::BroadcastInDimOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::BroadcastInDimOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
||||
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
||||
auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
|
||||
auto result_shape = result_type.getShape();
|
||||
|
||||
|
@ -444,9 +444,9 @@ class LhloBroadcastInDimConverter
|
|||
|
||||
// Inserts 'linalg.reshape' if there is a size-1 dim expansion.
|
||||
std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
|
||||
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
||||
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
||||
Value operand = operand_adaptor.operand();
|
||||
auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
|
||||
auto operand_shape = operand_type.getShape();
|
||||
|
@ -512,7 +512,7 @@ class LhloBroadcastInDimConverter
|
|||
return std::make_pair(operand, broadcast_dims);
|
||||
}
|
||||
|
||||
SmallVector<AffineMap, 2> getIndexingMaps(xla_lhlo::BroadcastInDimOp op,
|
||||
SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
|
||||
ArrayRef<int64_t> broadcastDims,
|
||||
ArrayRef<int64_t> resultShape,
|
||||
MemRefType operandType,
|
||||
|
@ -639,12 +639,12 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|||
}
|
||||
};
|
||||
|
||||
class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
|
||||
class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::IotaOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::IotaOp iotaOp, ArrayRef<Value> args,
|
||||
lmhlo::IotaOp iotaOp, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto resultMemrefType =
|
||||
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
|
||||
|
@ -680,12 +680,12 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
|
|||
}
|
||||
};
|
||||
|
||||
class ConstConverter : public OpConversionPattern<xla_lhlo::ConstOp> {
|
||||
class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::ConstOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ConstOp constOp, ArrayRef<Value> args,
|
||||
lmhlo::ConstOp constOp, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = constOp.getLoc();
|
||||
auto valueAttr = constOp.value().cast<DenseElementsAttr>();
|
||||
|
@ -726,12 +726,12 @@ class ReverseConverter
|
|||
}
|
||||
};
|
||||
|
||||
class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
|
||||
class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::SliceOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::SliceOp sliceOp, ArrayRef<Value> args,
|
||||
lmhlo::SliceOp sliceOp, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = sliceOp.getLoc();
|
||||
auto argType =
|
||||
|
@ -763,50 +763,50 @@ class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
|
|||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
|
||||
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
||||
ConstConverter,
|
||||
ConvToLinalgConverter,
|
||||
IotaConverter,
|
||||
LhloBroadcastInDimConverter,
|
||||
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::AddOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::CeilOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::CompareOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::ComplexOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::ConvertOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::AbsOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::AddOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::AndOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::CeilOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::CompareOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ComplexOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ConvertOp>,
|
||||
// TODO(ataei): Remove this pattern, CopyOp is folded away.
|
||||
PointwiseToLinalgConverter<xla_lhlo::CopyOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::CosOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::DivOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::ExpOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::ImagOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::LogOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::MaxOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::MinOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::MulOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::NegOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::RealOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::RemOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SelectOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SignOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SinOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
|
||||
ReshapeOpConverter<xla_lhlo::ReshapeOp>,
|
||||
ReverseConverter<xla_lhlo::ReverseOp>,
|
||||
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::CopyOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::CosOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::DivOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ExpOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ImagOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::LogOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::MaxOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::MinOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::MulOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::NegOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RealOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SelectOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SignOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SinOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SubOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::TanhOp>,
|
||||
ReshapeOpConverter<lmhlo::ReshapeOp>,
|
||||
ReverseConverter<lmhlo::ReverseOp>,
|
||||
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
|
||||
SliceConverter
|
||||
>(context);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// Converts LHLO ops to Linalg generic.
|
||||
// Sample result for xla_lhlo::AddOp.
|
||||
// Sample result for lmhlo::AddOp.
|
||||
//
|
||||
// "xla_lhlo.add"(%arg1, %arg2, %out) :
|
||||
// "lmhlo.add"(%arg1, %arg2, %out) :
|
||||
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
//
|
||||
// will be converted to
|
||||
|
@ -854,14 +854,14 @@ struct HloLegalizeToLinalg
|
|||
|
||||
} // namespace
|
||||
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
|
||||
return absl::make_unique<LhloLegalizeToLinalg>();
|
||||
}
|
||||
|
||||
static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
|
||||
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
|
||||
namespace mhlo {
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_result = "mhlo.exponential"(%tensor_operand)
|
||||
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -18,10 +18,10 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
|||
return %arg0 : tensor<4xf32>
|
||||
}
|
||||
// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
|
||||
// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
|
||||
// PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
|
||||
// PRE-NEXT: return
|
||||
// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||
// ESC-NOT: "xla_lhlo.copy"
|
||||
// ESC-NOT: "lmhlo.copy"
|
||||
// ESC-NEXT: return %[[ARG0]]
|
||||
|
||||
// -----
|
||||
|
@ -38,20 +38,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
|
|||
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
|
||||
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
|
||||
// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
|
||||
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
|
||||
// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
|
||||
// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
|
||||
// PRE-NEXT: return
|
||||
// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
|
||||
|
@ -67,14 +67,14 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
|||
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
|
||||
%sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
// BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
|
||||
// BOTH-NEXT: return
|
||||
|
@ -88,7 +88,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.copy"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -100,7 +100,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.exponential"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -112,7 +112,7 @@ func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.log"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.log"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
|||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
|
|||
%tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
|
||||
{comparison_direction = "EQ"}
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
|
||||
// BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
|
||||
tensor_store %tensor_result, %result : memref<2x2xi1>
|
||||
return
|
||||
}
|
||||
|
@ -154,7 +154,7 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
|
|||
%tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
|
||||
{broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
: (tensor<5xf32>) -> tensor<10x5xf32>
|
||||
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
tensor_store %tensor_result, %result : memref<10x5xf32>
|
||||
return
|
||||
}
|
||||
|
@ -205,12 +205,12 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
|||
// BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
|
||||
// BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
|
||||
|
||||
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
|
||||
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast
|
||||
// BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
|
||||
// BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
|
||||
// BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
|
||||
|
||||
// BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
|
||||
// BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
|
||||
// BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
// BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
|
||||
|
||||
|
@ -229,7 +229,7 @@ func @complex(%real: memref<2x2xf32>,
|
|||
%tensor_imag = tensor_load %imag : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
|
||||
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
|
||||
return
|
||||
}
|
||||
|
@ -241,7 +241,7 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||
%tensor_result = "mhlo.real"(%tensor_operand)
|
||||
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -253,7 +253,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||
%tensor_result = "mhlo.imag"(%tensor_operand)
|
||||
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -264,7 +264,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
|||
func @iota(%result: memref<10xi32>) {
|
||||
%tensor_result = "mhlo.iota"()
|
||||
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
|
||||
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
|
||||
// BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
|
||||
tensor_store %tensor_result, %result : memref<10xi32>
|
||||
return
|
||||
}
|
||||
|
@ -276,7 +276,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.abs"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -288,7 +288,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.ceil"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -300,7 +300,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.convert"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// BOTH-NOT: tensor_store
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
|
@ -313,7 +313,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.cosine"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -325,7 +325,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.negate"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -337,7 +337,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.rsqrt"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -349,7 +349,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.sign"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -361,7 +361,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.sqrt"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -373,7 +373,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.tanh"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -386,7 +386,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
|
|||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
@ -412,7 +412,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
|
|||
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
// BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -437,7 +437,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
|
|||
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
// BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -448,10 +448,10 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
|||
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
|
||||
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
|
||||
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||
%dot = "mhlo.dot"(%arg0, %arg0)
|
||||
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
|
||||
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
|
||||
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
|
||||
// ESC: return %[[ALLOC]]
|
||||
return %dot : tensor<1024x1024xf32>
|
||||
}
|
||||
|
@ -462,7 +462,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
|||
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
// BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
|
||||
// BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||
// BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||
// BOTH-SAME: padding = dense<[
|
||||
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
|
||||
|
|
|
@ -3,10 +3,10 @@
|
|||
// CHECK-LABEL: func @remove_simple
|
||||
func @remove_simple(%arg0: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -14,9 +14,9 @@ func @remove_simple(%arg0: memref<2x2xf32>) {
|
|||
// CHECK-LABEL: func @remove_without_dealloc
|
||||
func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -24,22 +24,22 @@ func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
|
|||
// CHECK-LABEL: func @replace_dependency
|
||||
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @keep_copies
|
||||
func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
// CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -50,14 +50,14 @@ func @must_not_be_removed(%arg0: memref<2x2xf32>,
|
|||
%arg2: memref<2x2xf32>) {
|
||||
// CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32>
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -67,13 +67,13 @@ func @must_be_removed_first(%arg0: memref<2x2xf32>,
|
|||
%arg1: memref<2x2xf32>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -83,11 +83,11 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>,
|
|||
%arg1: memref<2x2xf32>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
|
|
@ -10,18 +10,18 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
|||
%src: memref<56x56xf32>,
|
||||
%init: memref<f32>,
|
||||
%result: memref<112x112xf32>) {
|
||||
"xla_lhlo.select_and_scatter"(%arg, %src, %init, %result) ( {
|
||||
"lmhlo.select_and_scatter"(%arg, %src, %init, %result) ( {
|
||||
// select
|
||||
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %pred: memref<i1>):
|
||||
"xla_lhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} :
|
||||
"lmhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} :
|
||||
(memref<f32>, memref<f32>, memref<i1>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}, {
|
||||
// scatter
|
||||
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>):
|
||||
"xla_lhlo.add"(%lhs, %rhs, %out) :
|
||||
"lmhlo.add"(%lhs, %rhs, %out) :
|
||||
(memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}) {
|
||||
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||
window_dimensions = dense<[3, 3]> : tensor<2xi64>,
|
||||
|
@ -29,7 +29,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
|||
} : (memref<112x112xf32>,
|
||||
memref<56x56xf32>,
|
||||
memref<f32>, memref<112x112xf32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
// CHECK-LABEL: func @select_and_scatter(
|
||||
// CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>,
|
||||
|
@ -121,7 +121,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
|||
// CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
|
||||
|
||||
// Compute PRED.
|
||||
// CHECK: "xla_lhlo.compare"(
|
||||
// CHECK: "lmhlo.compare"(
|
||||
// CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
|
||||
// CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1>
|
||||
|
||||
|
@ -182,7 +182,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
|||
// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
|
||||
|
||||
// Compute scatter value.
|
||||
// CHECK: "xla_lhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
|
||||
// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
|
||||
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
|
|||
// CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
|
||||
// CHECK: return
|
||||
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
|
||||
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
|
||||
(memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
|
|||
func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: addf %{{.*}}, %{{.*}} : f32
|
||||
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
|
||||
"lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
|||
func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: addi %{{.*}}, %{{.*}} : i32
|
||||
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
|
||||
"lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
|||
func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: and %{{.*}}, %{{.*}} : i32
|
||||
"xla_lhlo.and"(%lhs, %rhs, %result) {name = "and.1"}
|
||||
"lmhlo.and"(%lhs, %rhs, %result) {name = "and.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
|||
func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: divf %{{.*}}, %{{.*}} : f32
|
||||
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
|
||||
"lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
|||
func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: divi_signed %{{.*}}, %{{.*}} : i32
|
||||
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
|
||||
"lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
|||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], %[[TWO:.*]] : f32
|
||||
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
|
||||
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
|
||||
"lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
|||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32
|
||||
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
|
||||
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
|
||||
"lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -92,7 +92,7 @@ func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
|||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: %[[CHECK:.*]] = cmpf "olt", %[[ONE:.*]], %[[TWO:.*]] : f32
|
||||
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
|
||||
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
|
||||
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
|||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32
|
||||
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
|
||||
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
|
||||
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -112,7 +112,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
|||
func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: mulf %{{.*}}, %{{.*}} : f32
|
||||
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
|
||||
"lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
|||
func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: muli %{{.*}}, %{{.*}} : i32
|
||||
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
|
||||
"lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -131,7 +131,7 @@ func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
|||
func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: subf %{{.*}}, %{{.*}} : f32
|
||||
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
|
||||
"lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -139,7 +139,7 @@ func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
|||
func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: subi %{{.*}}, %{{.*}} : i32
|
||||
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
|
||||
"lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -158,7 +158,7 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs:
|
|||
// CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32
|
||||
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
|
||||
// CHECK: return
|
||||
"xla_lhlo.dot"(%lhs, %rhs, %result) :
|
||||
"lmhlo.dot"(%lhs, %rhs, %result) :
|
||||
(memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -175,7 +175,7 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
|
|||
// CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32
|
||||
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
|
||||
// CHECK: return
|
||||
"xla_lhlo.dot"(%lhs, %rhs, %result) :
|
||||
"lmhlo.dot"(%lhs, %rhs, %result) :
|
||||
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -3,11 +3,11 @@
|
|||
func @reduce(%arg: memref<100x10xf32>,
|
||||
%init: memref<f32>,
|
||||
%result: memref<100xf32>) {
|
||||
"xla_lhlo.reduce"(%arg, %init, %result) ( {
|
||||
"lmhlo.reduce"(%arg, %init, %result) ( {
|
||||
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
"xla_lhlo.add"(%lhs, %rhs, %res)
|
||||
"lmhlo.add"(%lhs, %rhs, %res)
|
||||
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
} ) {dimensions = dense<[1]> : tensor<1xi64>}
|
||||
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
|
||||
return
|
||||
|
@ -25,7 +25,7 @@ func @reduce(%arg: memref<100x10xf32>,
|
|||
// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
|
||||
// CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref<f32, #map0>
|
||||
// CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref<f32, #map0>
|
||||
// CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
|
||||
// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
|
||||
// CHECK: }
|
||||
// CHECK: gpu.terminator
|
||||
// CHECK: }
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
// CHECK-LABEL: func @element_wise
|
||||
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||
%result: memref<2x2xf32>) {
|
||||
"xla_lhlo.add"(%lhs, %rhs, %result)
|
||||
"lmhlo.add"(%lhs, %rhs, %result)
|
||||
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
|||
func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>,
|
||||
%rhs: memref<?x?xf32>,
|
||||
%result: memref<?x?xf32>) {
|
||||
"xla_lhlo.add"(%lhs, %rhs, %result)
|
||||
"lmhlo.add"(%lhs, %rhs, %result)
|
||||
: (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>,
|
|||
// CHECK-LABEL: func @element_wise_scalar
|
||||
func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
|
||||
%result: memref<f32>) {
|
||||
"xla_lhlo.add"(%lhs, %rhs, %result)
|
||||
"lmhlo.add"(%lhs, %rhs, %result)
|
||||
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
|
|||
// CHECK-LABEL: func @minf
|
||||
func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||
%result: memref<2x2xf32>) {
|
||||
"xla_lhlo.minimum"(%lhs, %rhs, %result)
|
||||
"lmhlo.minimum"(%lhs, %rhs, %result)
|
||||
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -63,7 +63,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
|||
// CHECK-LABEL: func @maxi
|
||||
func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||
%result: memref<2x2xi32>) {
|
||||
"xla_lhlo.maximum"(%lhs, %rhs, %result)
|
||||
"lmhlo.maximum"(%lhs, %rhs, %result)
|
||||
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -78,7 +78,7 @@ func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
|||
// CHECK-LABEL: func @and
|
||||
func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||
%result: memref<2x2xi32>) {
|
||||
"xla_lhlo.and"(%lhs, %rhs, %result)
|
||||
"lmhlo.and"(%lhs, %rhs, %result)
|
||||
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -91,7 +91,7 @@ func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
|||
|
||||
// CHECK-LABEL: func @exp
|
||||
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.exponential"(%input, %result)
|
||||
"lmhlo.exponential"(%input, %result)
|
||||
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -104,7 +104,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @log
|
||||
func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -116,7 +116,7 @@ func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @copy
|
||||
func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
|
||||
"xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
|
||||
"lmhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -128,7 +128,7 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
|
|||
// CHECK-LABEL: func @float_cmp
|
||||
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||
%result: memref<2x2xi1>) {
|
||||
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
|
||||
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
|
||||
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -142,7 +142,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
|||
// CHECK-LABEL: func @int_cmp
|
||||
func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||
%result: memref<2x2xi1>) {
|
||||
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
|
||||
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
|
||||
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -156,7 +156,7 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
|||
// CHECK-LABEL: func @select
|
||||
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
||||
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.select"(%pred, %lhs, %rhs, %result)
|
||||
"lmhlo.select"(%pred, %lhs, %rhs, %result)
|
||||
: (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -170,7 +170,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
|||
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @iota
|
||||
func @iota(%out: memref<7x10xf32>) {
|
||||
"xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
|
||||
"lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.indexed_generic
|
||||
|
@ -186,7 +186,7 @@ func @iota(%out: memref<7x10xf32>) {
|
|||
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-LABEL: func @broadcast_scalar
|
||||
func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
|
||||
"xla_lhlo.broadcast"(%operand, %result) {
|
||||
"lmhlo.broadcast"(%operand, %result) {
|
||||
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
|
||||
} : (memref<f32>, memref<4x2x1xf32>) -> ()
|
||||
return
|
||||
|
@ -203,7 +203,7 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
|
|||
// CHECK-LABEL: func @broadcast
|
||||
func @broadcast(%operand: memref<4x?x16xf32>,
|
||||
%result: memref<4x2x1x4x?x16xf32>) {
|
||||
"xla_lhlo.broadcast"(%operand, %result) {
|
||||
"lmhlo.broadcast"(%operand, %result) {
|
||||
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
|
||||
} : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> ()
|
||||
return
|
||||
|
@ -220,7 +220,7 @@ func @broadcast(%operand: memref<4x?x16xf32>,
|
|||
// CHECK-LABEL: func @dynamic_broadcast_in_dim
|
||||
func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
|
||||
%result: memref<?x?x?x?x?xf32>) {
|
||||
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
||||
"lmhlo.broadcast_in_dim"(%operand, %result) {
|
||||
broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>
|
||||
} : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
|
||||
return
|
||||
|
@ -237,7 +237,7 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
|
|||
// CHECK-LABEL: func @static_broadcast_in_dim_no_expansion
|
||||
func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
|
||||
%result: memref<5x10xf32>) {
|
||||
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
||||
"lmhlo.broadcast_in_dim"(%operand, %result) {
|
||||
broadcast_dimensions = dense<[0]> : tensor<1xi64>
|
||||
} : (memref<5xf32>, memref<5x10xf32>) -> ()
|
||||
return
|
||||
|
@ -255,7 +255,7 @@ func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
|
|||
// CHECK-LABEL: func @static_broadcast_in_dim_expansion
|
||||
func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
|
||||
%result: memref<5x10x100xf32>) {
|
||||
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
||||
"lmhlo.broadcast_in_dim"(%operand, %result) {
|
||||
broadcast_dimensions = dense<[2, 0]> : tensor<2xi64>
|
||||
} : (memref<1x5xf32>, memref<5x10x100xf32>) -> ()
|
||||
return
|
||||
|
@ -274,7 +274,7 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
|
|||
// CHECK-LABEL: func @static_broadcast_in_dim_scalar
|
||||
func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
|
||||
%result: memref<5x10xf32>) {
|
||||
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
||||
"lmhlo.broadcast_in_dim"(%operand, %result) {
|
||||
broadcast_dimensions = dense<[]> : tensor<0xi64>
|
||||
} : (memref<f32>, memref<5x10xf32>) -> ()
|
||||
return
|
||||
|
@ -291,7 +291,7 @@ func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
|
|||
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one
|
||||
func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
|
||||
%result: memref<1x5xf32>) {
|
||||
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
||||
"lmhlo.broadcast_in_dim"(%operand, %result) {
|
||||
broadcast_dimensions = dense<[0]> : tensor<1xi64>
|
||||
} : (memref<1xf32>, memref<1x5xf32>) -> ()
|
||||
return
|
||||
|
@ -307,7 +307,7 @@ func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
|
|||
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many
|
||||
func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
|
||||
%result: memref<5x5xf32>) {
|
||||
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
|
||||
"lmhlo.broadcast_in_dim"(%operand, %result) {
|
||||
broadcast_dimensions = dense<[1]> : tensor<1xi64>
|
||||
} : (memref<1xf32>, memref<5x5xf32>) -> ()
|
||||
return
|
||||
|
@ -323,7 +323,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
|
|||
|
||||
// CHECK-LABEL: func @constant
|
||||
func @constant(%value: memref<i32>) {
|
||||
"xla_lhlo.constant"(%value) {
|
||||
"lmhlo.constant"(%value) {
|
||||
value = dense<10> : tensor<i32>
|
||||
} : (memref<i32>) -> ()
|
||||
return
|
||||
|
@ -335,7 +335,7 @@ func @constant(%value: memref<i32>) {
|
|||
|
||||
// CHECK-LABEL: func @absf
|
||||
func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -348,7 +348,7 @@ func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// CHECK-LABEL: func @absi
|
||||
func @absi(%input: memref<2x2xi32>,
|
||||
%result: memref<2x2xi32>) {
|
||||
"xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
"lmhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -364,7 +364,7 @@ func @absi(%input: memref<2x2xi32>,
|
|||
|
||||
// CHECK-LABEL: func @ceil
|
||||
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -376,7 +376,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @convert_i32_to_f32
|
||||
func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -389,7 +389,7 @@ func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
|
|||
// CHECK-LABEL: func @convert_i16_to_i32
|
||||
func @convert_i16_to_i32(%input: memref<2x2xi16>,
|
||||
%result: memref<2x2xi32>) {
|
||||
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> ()
|
||||
"lmhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -401,7 +401,7 @@ func @convert_i16_to_i32(%input: memref<2x2xi16>,
|
|||
|
||||
// CHECK-LABEL: func @convert_i32_to_i16
|
||||
func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) {
|
||||
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> ()
|
||||
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -413,7 +413,7 @@ func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) {
|
|||
|
||||
// CHECK-LABEL: func @convert_f32_to_f64
|
||||
func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) {
|
||||
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> ()
|
||||
"lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -425,7 +425,7 @@ func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) {
|
|||
|
||||
// CHECK-LABEL: func @convert_f64_to_f32
|
||||
func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -437,7 +437,7 @@ func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @convert_i32_to_i32
|
||||
func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -448,7 +448,7 @@ func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
|||
|
||||
// CHECK-LABEL: func @convert_f32_to_f32
|
||||
func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -459,7 +459,7 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @convert_f32_to_i32
|
||||
func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
|
||||
"xla_lhlo.convert"(%input, %result)
|
||||
"lmhlo.convert"(%input, %result)
|
||||
: (memref<2x2xf32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -472,7 +472,7 @@ func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
|
|||
|
||||
// CHECK-LABEL: func @cos
|
||||
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -485,7 +485,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// CHECK-LABEL: func @sin
|
||||
func @sin(%input: memref<2x2xf32>,
|
||||
%result: memref<2x2xf32>) {
|
||||
"xla_lhlo.sine"(%input, %result)
|
||||
"lmhlo.sine"(%input, %result)
|
||||
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -498,7 +498,7 @@ func @sin(%input: memref<2x2xf32>,
|
|||
|
||||
// CHECK-LABEL: func @negf
|
||||
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -510,7 +510,7 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @negi
|
||||
func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
"xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
"lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -524,7 +524,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
|||
// CHECK-LABEL: func @rem
|
||||
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||
%result: memref<2x2xf32>) {
|
||||
"xla_lhlo.remainder"(%lhs, %rhs, %result)
|
||||
"lmhlo.remainder"(%lhs, %rhs, %result)
|
||||
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -537,7 +537,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
|||
|
||||
// CHECK-LABEL: func @rsqrt
|
||||
func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -549,7 +549,7 @@ func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @sign
|
||||
func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -562,7 +562,7 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @sqrt
|
||||
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -574,7 +574,7 @@ func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @tanh
|
||||
func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -588,7 +588,7 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
func @complex(%real: memref<2x2xf32>,
|
||||
%imag: memref<2x2xf32>,
|
||||
%cplx: memref<2x2xcomplex<f32>>) {
|
||||
"xla_lhlo.complex"(%real, %imag, %cplx)
|
||||
"lmhlo.complex"(%real, %imag, %cplx)
|
||||
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -602,7 +602,7 @@ func @complex(%real: memref<2x2xf32>,
|
|||
// CHECK-LABEL: func @real
|
||||
func @real(%cplx: memref<2x2xcomplex<f32>>,
|
||||
%real: memref<2x2xf32>) {
|
||||
"xla_lhlo.real"(%cplx, %real)
|
||||
"lmhlo.real"(%cplx, %real)
|
||||
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -616,7 +616,7 @@ func @real(%cplx: memref<2x2xcomplex<f32>>,
|
|||
// CHECK-LABEL: func @imag
|
||||
func @imag(%cplx: memref<2x2xcomplex<f32>>,
|
||||
%imag: memref<2x2xf32>) {
|
||||
"xla_lhlo.imag"(%cplx, %imag)
|
||||
"lmhlo.imag"(%cplx, %imag)
|
||||
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -629,7 +629,7 @@ func @imag(%cplx: memref<2x2xcomplex<f32>>,
|
|||
|
||||
// CHECK: func @slice(%[[IN:.*]]: memref<?x?xf32>, %[[OUT:.*]]: memref<?x?xf32>)
|
||||
func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
|
||||
"xla_lhlo.slice"(%operand, %result) {
|
||||
"lmhlo.slice"(%operand, %result) {
|
||||
start_indices = dense<[0,1]> : tensor<2xi64>,
|
||||
limit_indices = dense<[2,3]> : tensor<2xi64>,
|
||||
strides = dense<[1,1]> : tensor<2xi64>
|
||||
|
@ -653,7 +653,7 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
|
|||
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
|
||||
// CHECK-LABEL: func @reshape_3D_2D
|
||||
func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
|
||||
"xla_lhlo.reshape"(%arg0, %arg1)
|
||||
"lmhlo.reshape"(%arg0, %arg1)
|
||||
: (memref<12x1x42xi32>, memref<12x42xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -666,7 +666,7 @@ func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
|
|||
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
|
||||
// CHECK-LABEL: func @reshape_4D_2D
|
||||
func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
|
||||
"xla_lhlo.reshape"(%arg0, %arg1)
|
||||
"lmhlo.reshape"(%arg0, %arg1)
|
||||
: (memref<12x42x1x1xi32>, memref<12x42xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -679,7 +679,7 @@ func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
|
|||
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
|
||||
// CHECK-LABEL: func @reshape_2D_4D
|
||||
func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
|
||||
"xla_lhlo.reshape"(%arg0, %arg1)
|
||||
"lmhlo.reshape"(%arg0, %arg1)
|
||||
: (memref<12x42xi32>, memref<12x1x42x1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -692,7 +692,7 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
|
|||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @reverse
|
||||
func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
|
||||
"xla_lhlo.reverse"(%arg0, %arg1) {
|
||||
"lmhlo.reverse"(%arg0, %arg1) {
|
||||
dimensions = dense<1> : tensor<1xi64>
|
||||
} : (memref<2x3xf32>, memref<2x3xf32>) -> ()
|
||||
return
|
||||
|
@ -710,15 +710,15 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m
|
|||
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// CHECK-SAME: strides = [2, 1]}
|
||||
// With all atributes explicitly specified.
|
||||
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
"lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
|
||||
// Dilation left unspecified, sets default dilation since linalg expects it.
|
||||
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
|
||||
// CHECK-SAME: dilations = [1, 1]
|
||||
// Padding is not set if it's zero.
|
||||
// CHECK-NOT: padding
|
||||
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
"lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
|
||||
"xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
// CHECK-LABEL: func @static_memref_cast
|
||||
func @static_memref_cast(%buf : memref<10x1x5xf32>) {
|
||||
%0 = xla_lhlo.static_memref_cast %buf
|
||||
%0 = lmhlo.static_memref_cast %buf
|
||||
: memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]>
|
||||
return
|
||||
}
|
||||
|
@ -38,7 +38,7 @@ func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
|
|||
%size_Y = constant 50 : index
|
||||
%stride_X = constant 1 : index
|
||||
%stride_Y = constant 0 : index
|
||||
%0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
|
||||
%0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
|
||||
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -3,11 +3,11 @@
|
|||
func @reduce(%arg: memref<100x10x5xf32>,
|
||||
%init: memref<f32>,
|
||||
%result: memref<100x5xf32>) {
|
||||
"xla_lhlo.reduce"(%arg, %init, %result) ( {
|
||||
"lmhlo.reduce"(%arg, %init, %result) ( {
|
||||
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
"xla_lhlo.add"(%lhs, %rhs, %res)
|
||||
"lmhlo.add"(%lhs, %rhs, %res)
|
||||
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
} ) {dimensions = dense<[1]> : tensor<1xi64>}
|
||||
: (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
|
||||
return
|
||||
|
@ -35,7 +35,7 @@ func @reduce(%arg: memref<100x10x5xf32>,
|
|||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
|
@ -49,11 +49,11 @@ func @reduce(%arg: memref<100x10x5xf32>,
|
|||
func @reduce_no_outer_loop(%arg: memref<100xf32>,
|
||||
%init: memref<f32>,
|
||||
%result: memref<1xf32>) {
|
||||
"xla_lhlo.reduce"(%arg, %init, %result) ( {
|
||||
"lmhlo.reduce"(%arg, %init, %result) ( {
|
||||
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
"xla_lhlo.add"(%lhs, %rhs, %res)
|
||||
"lmhlo.add"(%lhs, %rhs, %res)
|
||||
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
} ) {dimensions = dense<[0]> : tensor<1xi64>}
|
||||
: (memref<100xf32>, memref<f32>, memref<1xf32>) -> ()
|
||||
return
|
||||
|
@ -76,7 +76,7 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
|
|||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: scf.reduce.return [[ACC_RESULT]]
|
||||
// CHECK: }
|
||||
|
@ -88,11 +88,11 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
|
|||
func @dynamic_reduce(%arg: memref<?x?x?xf32>,
|
||||
%init: memref<f32>,
|
||||
%result: memref<?x?xf32>) {
|
||||
"xla_lhlo.reduce"(%arg, %init, %result) ( {
|
||||
"lmhlo.reduce"(%arg, %init, %result) ( {
|
||||
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
"xla_lhlo.add"(%lhs, %rhs, %res)
|
||||
"lmhlo.add"(%lhs, %rhs, %res)
|
||||
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
} ) {dimensions = dense<[1]> : tensor<1xi64>}
|
||||
: (memref<?x?x?xf32>, memref<f32>, memref<?x?xf32>) -> ()
|
||||
return
|
||||
|
@ -121,7 +121,7 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
|
|||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
|
@ -135,11 +135,11 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
|
|||
func @reduce_window(%arg: memref<112x112xf32>,
|
||||
%init: memref<f32>,
|
||||
%result: memref<56x56xf32>) {
|
||||
"xla_lhlo.reduce_window"(%arg, %init, %result) ( {
|
||||
"lmhlo.reduce_window"(%arg, %init, %result) ( {
|
||||
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
"xla_lhlo.maximum"(%lhs, %rhs, %res)
|
||||
"lmhlo.maximum"(%lhs, %rhs, %res)
|
||||
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}) {
|
||||
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||
window_dimensions = dense<[3, 3]> : tensor<2xi64>,
|
||||
|
@ -189,7 +189,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
|
|||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
// CHECK-LABEL: func @ceil
|
||||
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
// expected-error@+1{{must be memref of floating-point values}}
|
||||
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
"lmhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -20,7 +20,7 @@ func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
|||
|
||||
// CHECK-LABEL: func @cos
|
||||
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -28,7 +28,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @cos
|
||||
func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
|
||||
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
|
||||
"lmhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -36,7 +36,7 @@ func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
|
|||
|
||||
func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
"lmhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -44,7 +44,7 @@ func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
|||
|
||||
// CHECK-LABEL: func @sin
|
||||
func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -52,7 +52,7 @@ func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @sin
|
||||
func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
|
||||
"xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
|
||||
"lmhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -60,7 +60,7 @@ func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
|
|||
|
||||
func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
"lmhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -68,7 +68,7 @@ func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
|||
|
||||
// CHECK-LABEL: func @add_memrefs
|
||||
func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -76,7 +76,7 @@ func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1
|
|||
|
||||
// CHECK-LABEL: func @abs_memref
|
||||
func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -84,7 +84,7 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @convert_memref
|
||||
func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
|
||||
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
|
||||
"lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -92,7 +92,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
|
|||
|
||||
func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
|
||||
// expected-error@+1{{requires the same shape for all operands}}
|
||||
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
|
||||
"lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -100,7 +100,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @exp
|
||||
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -108,7 +108,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @exp
|
||||
func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
|
||||
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
|
||||
"lmhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -116,7 +116,7 @@ func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
|
|||
|
||||
func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
"lmhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -124,7 +124,7 @@ func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
|||
|
||||
// CHECK-LABEL: func @log_memref
|
||||
func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -132,7 +132,7 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @log_memref
|
||||
func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
|
||||
"lmhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -140,7 +140,7 @@ func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) ->
|
|||
|
||||
func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
"lmhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -148,7 +148,7 @@ func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @neg_memref
|
||||
func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -156,7 +156,7 @@ func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @rsqrt_memref
|
||||
func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -164,7 +164,7 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @rsqrt_memref
|
||||
func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
|
||||
"lmhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -172,7 +172,7 @@ func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>)
|
|||
|
||||
func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
"lmhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -180,7 +180,7 @@ func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @sqrt_memref
|
||||
func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -188,7 +188,7 @@ func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @sqrt_memref
|
||||
func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
|
||||
"lmhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -196,7 +196,7 @@ func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -
|
|||
|
||||
func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
"lmhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -204,7 +204,7 @@ func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @sign_memref
|
||||
func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -212,7 +212,7 @@ func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @tanh_memref
|
||||
func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -220,7 +220,7 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @tanh_memref
|
||||
func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
|
||||
"lmhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -228,15 +228,15 @@ func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -
|
|||
|
||||
func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
"lmhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
|
||||
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
|
||||
"xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
|
||||
// expected-error@+1{{'lmhlo.tanh' op requires all operands to have the same type}}
|
||||
"lmhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -244,7 +244,7 @@ func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @add_memref
|
||||
func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -252,7 +252,7 @@ func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
|
|||
|
||||
// CHECK-LABEL: func @div_memref
|
||||
func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -260,7 +260,7 @@ func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
|
|||
|
||||
// CHECK-LABEL: func @max_memref
|
||||
func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -268,7 +268,7 @@ func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
|
|||
|
||||
// CHECK-LABEL: func @min_memref
|
||||
func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -276,7 +276,7 @@ func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
|
|||
|
||||
// CHECK-LABEL: func @mul_memref
|
||||
func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -284,7 +284,7 @@ func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
|
|||
|
||||
// CHECK-LABEL: func @sub_memref
|
||||
func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -292,7 +292,7 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
|
|||
|
||||
// CHECK-LABEL: func @and_memref
|
||||
func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
|
||||
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
|
||||
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -300,7 +300,7 @@ func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32
|
|||
|
||||
// CHECK-LABEL: func @and_memref
|
||||
func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
|
||||
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
|
||||
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -308,7 +308,7 @@ func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>)
|
|||
|
||||
func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
|
||||
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -316,7 +316,7 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
|
|||
|
||||
// CHECK-LABEL: func @or_memref
|
||||
func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
|
||||
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
|
||||
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -324,7 +324,7 @@ func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>
|
|||
|
||||
// CHECK-LABEL: func @or_memref
|
||||
func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
|
||||
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
|
||||
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -332,7 +332,7 @@ func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -
|
|||
|
||||
func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
|
||||
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -340,7 +340,7 @@ func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>
|
|||
|
||||
// CHECK-LABEL: func @xor_memref
|
||||
func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
|
||||
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
|
||||
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -348,7 +348,7 @@ func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32
|
|||
|
||||
// CHECK-LABEL: func @xor_memref
|
||||
func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
|
||||
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
|
||||
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -356,7 +356,7 @@ func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>)
|
|||
|
||||
func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
|
||||
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -364,7 +364,7 @@ func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
|
|||
|
||||
// CHECK-LABEL: func @broadcast_in_dim_memref
|
||||
func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
|
||||
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
|
||||
"lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -372,7 +372,7 @@ func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -
|
|||
|
||||
// CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref
|
||||
func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi32>) -> () {
|
||||
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
|
||||
"lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -381,10 +381,10 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi
|
|||
|
||||
// CHECK-LABEL: func @reduce_memref
|
||||
func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.reduce"(%input, %init, %out) ( {
|
||||
"lmhlo.reduce"(%input, %init, %out) ( {
|
||||
^bb0(%arg1: memref<f32>, %arg2: memref<f32>, %result: memref<f32>):
|
||||
"xla_lhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
} ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref<f32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -393,14 +393,14 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf
|
|||
|
||||
// CHECK-LABEL: func @fusion_memref
|
||||
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.fusion"() ( {
|
||||
"lmhlo.fusion"() ( {
|
||||
%0 = tensor_load %input1 : memref<10xf32>
|
||||
%1 = tensor_load %input2 : memref<10xf32>
|
||||
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||
%3 = tensor_load %input3 : memref<10xf32>
|
||||
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||
tensor_store %4, %out : memref<10xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
} ) : () -> ()
|
||||
return
|
||||
}
|
||||
|
@ -409,18 +409,18 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
|
|||
|
||||
// CHECK-LABEL: func @case_memref
|
||||
func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () {
|
||||
"xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
|
||||
"lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
|
||||
^bb0(%arg0: memref<f32>):
|
||||
"xla_lhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}, {
|
||||
^bb0(%arg0: memref<f32>):
|
||||
"xla_lhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}, {
|
||||
^bb0(%arg0: memref<f32>):
|
||||
"xla_lhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>}
|
||||
: (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
|
@ -430,7 +430,7 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr
|
|||
// -----
|
||||
|
||||
func @static_memref_cast(%in: memref<10x1xf32>) {
|
||||
%out = xla_lhlo.static_memref_cast %in
|
||||
%out = lmhlo.static_memref_cast %in
|
||||
: memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]>
|
||||
return
|
||||
}
|
||||
|
@ -440,7 +440,7 @@ func @static_memref_cast(%in: memref<10x1xf32>) {
|
|||
|
||||
func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
|
||||
// expected-error @+1 {{operand must have static shape}}
|
||||
%out = xla_lhlo.static_memref_cast %in
|
||||
%out = lmhlo.static_memref_cast %in
|
||||
: memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]>
|
||||
return
|
||||
}
|
||||
|
@ -449,7 +449,7 @@ func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
|
|||
|
||||
func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
|
||||
// expected-error @+1 {{result must have static shape}}
|
||||
%out = xla_lhlo.static_memref_cast %in
|
||||
%out = lmhlo.static_memref_cast %in
|
||||
: memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]>
|
||||
return
|
||||
}
|
||||
|
@ -459,7 +459,7 @@ func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
|
|||
func @dynamic_memref_cast(%in: memref<?xf32>) {
|
||||
%size = constant 10 : index
|
||||
%step = constant 1 : index
|
||||
%out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
|
||||
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
|
||||
: memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>
|
||||
return
|
||||
}
|
||||
|
@ -471,7 +471,7 @@ func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
|
|||
// expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}}
|
||||
%size = constant 10 : index
|
||||
%step = constant 1 : index
|
||||
%out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
|
||||
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
|
||||
: memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
|
||||
return
|
||||
}
|
||||
|
@ -483,19 +483,19 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
|
|||
// CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>,
|
||||
// CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32>
|
||||
|
||||
// CHECK-NEXT: [[DYN_VEC:%.*]] = xla_lhlo.reshape_memref_cast [[UNRANKED]]
|
||||
// CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]]
|
||||
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
|
||||
%dyn_vec = xla_lhlo.reshape_memref_cast %unranked(%shape1)
|
||||
%dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1)
|
||||
: (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
|
||||
|
||||
// CHECK-NEXT: [[DYN_MAT:%.*]] = xla_lhlo.reshape_memref_cast [[DYN_VEC]]
|
||||
// CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]]
|
||||
// CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
|
||||
%dyn_mat = xla_lhlo.reshape_memref_cast %dyn_vec(%shape2)
|
||||
%dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2)
|
||||
: (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
|
||||
|
||||
// CHECK-NEXT: {{%.*}} = xla_lhlo.reshape_memref_cast [[DYN_MAT]]
|
||||
// CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]]
|
||||
// CHECK-SAME: : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
|
||||
%new_unranked = xla_lhlo.reshape_memref_cast %dyn_mat(%shape3)
|
||||
%new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3)
|
||||
: (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
|
||||
return
|
||||
}
|
||||
|
@ -505,7 +505,7 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
|
|||
func @reshape_memref_cast_element_type_mismatch(
|
||||
%buf: memref<*xf32>, %shape: memref<1xi32>) {
|
||||
// expected-error @+1 {{element types of source and destination memref types should be the same}}
|
||||
xla_lhlo.reshape_memref_cast %buf(%shape)
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
|
||||
}
|
||||
|
||||
|
@ -514,7 +514,7 @@ func @reshape_memref_cast_element_type_mismatch(
|
|||
func @reshape_memref_cast_dst_ranked_shape_unranked(
|
||||
%buf: memref<*xf32>, %shape: memref<?xi32>) {
|
||||
// expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}}
|
||||
xla_lhlo.reshape_memref_cast %buf(%shape)
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
|
||||
return
|
||||
}
|
||||
|
@ -524,7 +524,7 @@ func @reshape_memref_cast_dst_ranked_shape_unranked(
|
|||
func @reshape_memref_cast_dst_shape_rank_mismatch(
|
||||
%buf: memref<*xf32>, %shape: memref<1xi32>) {
|
||||
// expected-error @+1 {{length of shape operand differs from the result's memref rank}}
|
||||
xla_lhlo.reshape_memref_cast %buf(%shape)
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
|
||||
return
|
||||
}
|
||||
|
@ -535,7 +535,7 @@ func @reshape_memref_cast_affine_map_is_not_identity(
|
|||
%buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
|
||||
%shape: memref<1xi32>) {
|
||||
// expected-error @+1 {{operand memref type should have identity affine map}}
|
||||
xla_lhlo.reshape_memref_cast %buf(%shape)
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
|
||||
-> memref<8xf32>
|
||||
return
|
||||
|
@ -545,7 +545,7 @@ func @reshape_memref_cast_affine_map_is_not_identity(
|
|||
|
||||
// CHECK-LABEL: func @atan2_memrefs
|
||||
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -553,7 +553,7 @@ func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref
|
|||
|
||||
// CHECK-LABEL: func @atan2_memrefs
|
||||
func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -561,7 +561,7 @@ func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>
|
|||
|
||||
func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -569,7 +569,7 @@ func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref
|
|||
|
||||
// CHECK-LABEL: func @bitcast_convert_memrefs
|
||||
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
|
||||
"lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -577,7 +577,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) ->
|
|||
|
||||
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () {
|
||||
// expected-error@+1{{requires the same shape for all operands}}
|
||||
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
|
||||
"lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -585,7 +585,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) ->
|
|||
|
||||
// CHECK-LABEL: func @clz_memrefs
|
||||
func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -593,7 +593,7 @@ func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @expm1_memrefs
|
||||
func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -601,7 +601,7 @@ func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @expm1_memrefs
|
||||
func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
"lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -609,7 +609,7 @@ func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f3
|
|||
|
||||
// CHECK-LABEL: func @floor_memrefs
|
||||
func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -617,7 +617,7 @@ func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
|||
|
||||
func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point values}}
|
||||
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -625,7 +625,7 @@ func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @imag_memrefs
|
||||
func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
|
||||
"lmhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -633,7 +633,7 @@ func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> ()
|
|||
|
||||
func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error@+1{{must be memref of complex-type values}}
|
||||
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -641,7 +641,7 @@ func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @real_memrefs
|
||||
func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
|
||||
"lmhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -649,7 +649,7 @@ func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> ()
|
|||
|
||||
func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error@+1{{must be memref of complex-type values}}
|
||||
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -657,7 +657,7 @@ func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @is_finite_memrefs
|
||||
func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
|
||||
"xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
|
||||
"lmhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -665,7 +665,7 @@ func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @log1p_memrefs
|
||||
func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -673,7 +673,7 @@ func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @log1p_memrefs
|
||||
func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
"lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -681,7 +681,7 @@ func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f3
|
|||
|
||||
func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
"lmhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -689,7 +689,7 @@ func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @not_memrefs
|
||||
func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -697,7 +697,7 @@ func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @not_memrefs
|
||||
func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
|
||||
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
|
||||
"lmhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -705,7 +705,7 @@ func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
|
|||
|
||||
func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
|
||||
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -713,7 +713,7 @@ func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @popcnt_memrefs
|
||||
func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -721,7 +721,7 @@ func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
|||
|
||||
func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
|
||||
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -729,7 +729,7 @@ func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @reduce_precision_memrefs
|
||||
func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -737,7 +737,7 @@ func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) ->
|
|||
|
||||
// CHECK-LABEL: func @round_memrefs
|
||||
func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -745,7 +745,7 @@ func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
|||
|
||||
func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point values}}
|
||||
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -753,7 +753,7 @@ func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @shift_left_memrefs
|
||||
func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -761,7 +761,7 @@ func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: m
|
|||
|
||||
func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
|
||||
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -769,7 +769,7 @@ func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: m
|
|||
|
||||
// CHECK-LABEL: func @shift_right_arithmetic_memrefs
|
||||
func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -777,7 +777,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>,
|
|||
|
||||
func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
|
||||
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -785,7 +785,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>,
|
|||
|
||||
// CHECK-LABEL: func @shift_right_logical_memrefs
|
||||
func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -793,7 +793,7 @@ func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %a
|
|||
|
||||
func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
|
||||
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -801,14 +801,14 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
|
|||
|
||||
// CHECK-LABEL: func @all_reduce_memrefs
|
||||
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
|
||||
"lmhlo.all_reduce"(%arg0, %arg_out) ({
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
|
||||
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
|
||||
"mhlo.return"(%max) : (tensor<f32>) -> ()
|
||||
})
|
||||
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
|
||||
|
||||
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
|
||||
"lmhlo.all_reduce"(%arg0, %arg_out) ({
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
|
||||
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
|
||||
"mhlo.return"(%max) : (tensor<f32>) -> ()
|
||||
|
@ -826,11 +826,11 @@ func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> ()
|
|||
|
||||
// CHECK-LABEL: func @collective_permute_memrefs
|
||||
func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
|
||||
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
|
||||
"lmhlo.collective_permute"(%arg0, %arg_out) {
|
||||
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
|
||||
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
|
||||
|
||||
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
|
||||
"lmhlo.collective_permute"(%arg0, %arg_out) {
|
||||
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
|
||||
channel_id = { handle = 5 : i64, type = 2 : i64 }
|
||||
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
|
||||
|
@ -841,7 +841,7 @@ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128
|
|||
|
||||
// CHECK-LABEL: func @fft_memrefs
|
||||
func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
|
||||
"lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -852,7 +852,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>,
|
|||
%arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
|
||||
%grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>,
|
||||
%grad_offset: memref<8xf32>) -> () {
|
||||
"xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
"lmhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
|
||||
memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
|
||||
return
|
||||
|
@ -863,7 +863,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>,
|
|||
// CHECK-LABEL: func @batch_norm_inference_memrefs
|
||||
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
|
||||
%arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
|
||||
"xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
"lmhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -874,7 +874,7 @@ func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf
|
|||
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
|
||||
%output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>,
|
||||
%batch_var: memref<8xf32>) -> () {
|
||||
"xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
"lmhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -883,8 +883,8 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
|
|||
|
||||
// CHECK-LABEL: func @cholesky_memrefs
|
||||
func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () {
|
||||
"xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
|
||||
"xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
|
||||
"lmhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
|
||||
"lmhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -892,7 +892,7 @@ func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291x
|
|||
|
||||
// CHECK-LABEL: func @infeed_memrefs
|
||||
func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
|
||||
"xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
|
||||
"lmhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -900,7 +900,7 @@ func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @outfeed_memrefs
|
||||
func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
|
||||
"xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
|
||||
"lmhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -908,7 +908,7 @@ func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @replica_id_memrefs
|
||||
func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
|
||||
"xla_lhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
|
||||
"lmhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -916,7 +916,7 @@ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @triangular_solve_memrefs
|
||||
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
|
||||
"xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
|
||||
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
|
||||
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -925,9 +925,9 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %
|
|||
|
||||
// CHECK-LABEL: func @while_memrefs
|
||||
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
|
||||
"xla_lhlo.while"(%arg0, %arg_out) (
|
||||
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
|
||||
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "xla_lhlo.terminator"() : () -> () }
|
||||
"lmhlo.while"(%arg0, %arg_out) (
|
||||
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
|
||||
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "lmhlo.terminator"() : () -> () }
|
||||
) : (memref<i64>, memref<i64>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -936,9 +936,9 @@ func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @while_memrefs
|
||||
func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>) -> () {
|
||||
"xla_lhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) (
|
||||
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
|
||||
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "xla_lhlo.terminator"() : () -> () }
|
||||
"lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) (
|
||||
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
|
||||
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () }
|
||||
) : (memref<i64>, memref<5xf32>, memref<i64>, memref<5xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -947,7 +947,7 @@ func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<
|
|||
|
||||
// CHECK-LABEL: func @bitcast_memrefs
|
||||
func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
|
||||
"xla_lhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
|
||||
"lmhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -956,7 +956,7 @@ func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
|
|||
// CHECK-LABEL: func @scatter_memrefs
|
||||
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
|
||||
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
|
||||
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
|
||||
"lmhlo.scatter" (%input, %indices, %updates, %arg_out) ({
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
|
||||
%add = mhlo.add %lhs, %rhs : tensor<f32>
|
||||
"mhlo.return"(%add) : (tensor<f32>) -> ()
|
||||
|
@ -977,7 +977,7 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
|
|||
|
||||
// CHECK-LABEL: func @map_memrefs
|
||||
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
|
||||
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
"lmhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>):
|
||||
%c = mhlo.add %a, %b : tensor<f32>
|
||||
"mhlo.return"(%c) : (tensor<f32>) -> ()
|
||||
|
@ -989,7 +989,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
|
|||
|
||||
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () {
|
||||
// expected-error@+1{{requires the same shape for all operands}}
|
||||
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
"lmhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>):
|
||||
%c = mhlo.add %a, %b : tensor<f32>
|
||||
"mhlo.return"(%c) : (tensor<f32>) -> ()
|
||||
|
@ -1001,7 +1001,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
|
|||
|
||||
// CHECK-LABEL: func @rng_get_and_update_state_memrefs
|
||||
func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
|
||||
"xla_lhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
|
||||
"lmhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -1010,7 +1010,7 @@ func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
|
|||
// CHECK-LABEL: func @sort_memrefs
|
||||
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||
|
@ -1023,7 +1023,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
|||
// CHECK-LABEL: func @sort_memrefs
|
||||
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||
|
@ -1036,7 +1036,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
|||
// CHECK-LABEL: func @sort_memrefs
|
||||
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||
|
|
Loading…
Reference in New Issue