Rename `xla_hlo` dialect to `mhlo`

This is part of the current refactoring of the HLO related dialect.
`xla_hlo` will be reintroduced in a new form later.

PiperOrigin-RevId: 319916753
This commit is contained in:
Mehdi Amini 2020-07-07 04:51:24 +00:00 committed by Mehdi Amini
parent fa057cc0bc
commit 8900222fed
46 changed files with 1163 additions and 1174 deletions

View File

@ -17,12 +17,12 @@ limitations under the License.
// These ops are not necessarily orthogonal or optimized for transformation but
// for ease of expression in certain cases deemed important for client
// libraries (i.e. implicit broadcasting, helper ops, etc).
// This dialect is considered to exist in addition to augment the xla_hlo
// This dialect is considered to exist in addition to augment the mhlo
// dialect for ergonomic needs, not duplicate/replace it.
//
// The typical use of this dialect is for client libraries to be able to emit
// less constrained ops and rely on the conversion framework to lower any
// xla_chlo ops to canonical xla_hlo ops.
// xla_chlo ops to canonical mhlo ops.
//
// See: https://www.tensorflow.org/xla/operation_semantics
@ -44,7 +44,7 @@ def HLOClient_Dialect : Dialect {
let description = [{
This dialect contains ops that align closely with the API surface area
of the XlaBuilder C++ API, where such ops have semantics that go beyond
what exists in the lower level dialects (such as `xla_hlo`). Essentially,
what exists in the lower level dialects (such as `mhlo`). Essentially,
whenever the client library uses syntactic sugar or composition
of multiple ops for an API call, this dialect tries to model the API call
and provide conversion patterns to fully materialize into lower level
@ -65,7 +65,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
// shape broadcasting.
//
// These correspond to operations in the xla_hlo dialect without the
// These correspond to operations in the mhlo dialect without the
// "broadcast_" prefix, except that those ops require same-shaped operands and
// results.
//

View File

@ -37,12 +37,12 @@ class OpBuilder;
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc"
namespace xla_hlo {
namespace mhlo {
class XlaHloDialect : public Dialect {
public:
explicit XlaHloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "xla_hlo"; }
static StringRef getDialectNamespace() { return "mhlo"; }
// Registered hook to materialize a constant operation from a given attribute
// value with the desired resultant type.
@ -82,7 +82,7 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
// %1 = index_cast %0 : index to i64
// %2 = dim %arg0, 1 : memref<?x?xf32>
// %3 = index_cast %2 : index to i64
// %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
// %4 = "mhlo.scalars_to_dimension_tensor"(%1, %3)
// : (i64, i64) -> tensor<2xi64>
//
// and returns %4 as the shape value.
@ -93,7 +93,7 @@ LogicalResult deriveShapeFromFirstOperand(
#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
} // end namespace xla_hlo
} // end namespace mhlo
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_

View File

@ -29,8 +29,8 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
def HLO_Dialect : Dialect {
let name = "xla_hlo";
let cppNamespace = "xla_hlo";
let name = "mhlo";
let cppNamespace = "mhlo";
}
class HLO_Op<string mnemonic, list<OpTrait> traits> :

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
template <typename HloOpTy>
struct HloToLhloOpImpl {
@ -31,10 +31,10 @@ struct HloToLhloOpImpl {
template <typename HloOpTy>
using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
#define MAP_HLO_TO_LHLO(OpName) \
template <> \
struct HloToLhloOpImpl<xla_hlo::OpName> { \
using Type = xla_lhlo::OpName; \
#define MAP_HLO_TO_LHLO(OpName) \
template <> \
struct HloToLhloOpImpl<mhlo::OpName> { \
using Type = xla_lhlo::OpName; \
}
MAP_HLO_TO_LHLO(AbsOp);
@ -74,7 +74,7 @@ MAP_HLO_TO_LHLO(TanhOp);
#undef MAP_HLO_TO_LHLO
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_

View File

@ -464,7 +464,7 @@ struct XlaOpToStdScalarOp {
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
@ -472,8 +472,8 @@ struct XlaOpToStdScalarOp {
args, b);
}
// Implementation for HLO ops except xla_hlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
// Implementation for HLO ops except mhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, std::false_type>::value>>
@ -493,10 +493,11 @@ struct XlaOpToStdScalarOp {
op.getLoc(), comparison_direction, result_types, args, b);
}
// Implementation for xla_hlo::CompareOp.
template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
HloOpTy, xla_hlo::CompareOp>::value>>
static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
// Implementation for mhlo::CompareOp.
template <typename HloOpTy,
typename =
std::enable_if_t<std::is_same<HloOpTy, mhlo::CompareOp>::value>>
static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(

View File

@ -29,7 +29,7 @@ template <typename T>
class OperationPass;
class Pass;
namespace xla_hlo {
namespace mhlo {
/// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
@ -55,10 +55,10 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
// necessary to export to XLA.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
// fuse xla_hlo ops to kLoop/kInput fusion patterns
// fuse mhlo ops to kLoop/kInput fusion patterns
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
} // namespace xla_hlo
} // namespace mhlo
namespace xla_lhlo {

View File

@ -27,7 +27,7 @@ class LLVMTypeConverter;
class LowerToLLVMOptions;
class OwningRewritePatternList;
class BufferAssignmentPlacer;
namespace xla_hlo {
namespace mhlo {
// Collection of rewrite patterns for lowering a general dot product.
void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
@ -73,7 +73,7 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace xla_hlo
} // namespace mhlo
namespace xla_lhlo {

View File

@ -18,7 +18,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
// Static initialization for XLA dialect registration.
static mlir::DialectRegistration<mlir::xla_hlo::XlaHloDialect> xla_hlo_ops;
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
xla_chlo_ops;
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;

View File

@ -60,7 +60,7 @@ limitations under the License.
namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
namespace xla_hlo {
namespace mhlo {
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
Attribute value, Type type,
@ -68,8 +68,7 @@ Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
// HLO dialect constants only support ElementsAttr unlike standard dialect
// constant which supports all attributes.
if (value.isa<ElementsAttr>())
return builder.create<xla_hlo::ConstOp>(loc, type,
value.cast<ElementsAttr>());
return builder.create<mhlo::ConstOp>(loc, type, value.cast<ElementsAttr>());
return nullptr;
}
@ -167,7 +166,7 @@ void ConstOp::build(OpBuilder& builder, OperationState& result,
}
// TODO: support other XLA specific types.
assert(type && "unsupported attribute type for building xla_hlo.constant");
assert(type && "unsupported attribute type for building mhlo.constant");
result.types.push_back(type);
result.addAttribute("value", value);
}
@ -387,7 +386,7 @@ static LogicalResult Verify(GetTupleElementOp op) {
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
if (auto tupleOp =
dyn_cast_or_null<xla_hlo::TupleOp>(getOperand().getDefiningOp())) {
dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
return tupleOp.getOperand(index().getLimitedValue());
}
@ -693,10 +692,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
}
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
auto real_op =
dyn_cast_or_null<xla_hlo::RealOp>(getOperand(0).getDefiningOp());
auto imag_op =
dyn_cast_or_null<xla_hlo::ImagOp>(getOperand(1).getDefiningOp());
auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp());
auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp());
if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
return real_op.getOperand();
}
@ -727,7 +724,7 @@ void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) {
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op =
dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
return complex_op.getOperand(1);
}
@ -740,7 +737,7 @@ void RealOp::build(OpBuilder& builder, OperationState& state, Value val) {
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op =
dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
return complex_op.getOperand(0);
}
@ -1148,7 +1145,7 @@ static LogicalResult Verify(MapOp op) {
// RecvOp
//===----------------------------------------------------------------------===//
// Checks that the result type is of the form `tuple<any_type, xla_hlo::token>`
// Checks that the result type is of the form `tuple<any_type, mhlo::token>`
static LogicalResult Verify(RecvOp op) {
auto result_ty = op.getResult().getType().cast<TupleType>();
auto subtypes = result_ty.getTypes();
@ -2020,7 +2017,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
//===----------------------------------------------------------------------===//
// xla_hlo Dialect Interfaces
// mhlo Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
@ -2032,7 +2029,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
BlockAndValueMapping& valueMapping) const final {
return true;
}
// Operations in xla_hlo dialect are always legal to inline since they are
// Operations in mhlo dialect are always legal to inline since they are
// pure.
bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
return true;
@ -2041,7 +2038,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// xla_hlo Dialect Constructor
// mhlo Dialect Constructor
//===----------------------------------------------------------------------===//
XlaHloDialect::XlaHloDialect(MLIRContext* context)
@ -2061,8 +2058,7 @@ Type XlaHloDialect::parseType(DialectAsmParser& parser) const {
if (parser.parseKeyword(&data_type)) return Type();
if (data_type == "token") return TokenType::get(getContext());
parser.emitError(parser.getNameLoc())
<< "unknown xla_hlo type: " << data_type;
parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type;
return nullptr;
}
@ -2071,7 +2067,7 @@ void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const {
os << "token";
return;
}
os << "<unknown xla_hlo type>";
os << "<unknown mhlo type>";
}
//===----------------------------------------------------------------------===//
@ -2106,5 +2102,5 @@ LogicalResult deriveShapeFromFirstOperand(
return success();
}
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -30,7 +30,7 @@ namespace xla_chlo {
namespace {
// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding xla_hlo non-broadcasting op.
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
@ -63,7 +63,7 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
};
// Converts a binary op with ranked broadcasting operands to explicitly
// broadcast and invoke the corresponding xla_hlo non-broadcasting op.
// broadcast and invoke the corresponding mhlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
// https://docs.scipy.org/doc/numpy/reference/ufuncs.html
@ -136,7 +136,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
// properly.
auto lhs_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
Value broadcasted_lhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
lhs_type.getElementType()),
@ -144,7 +144,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
auto rhs_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
Value broadcasted_rhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
rhs_type.getElementType()),
@ -182,23 +182,21 @@ struct HloBinaryElementwiseAdaptor {
};
struct HloComplexAdaptor {
static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op,
Type result_type, Value broadcasted_lhs,
Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<xla_hlo::ComplexOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
struct HloCompareAdaptor {
static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op,
Type result_type, Value broadcasted_lhs,
Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction());
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::CompareOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction());
}
};
@ -214,28 +212,27 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
patterns);
POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp);
POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp);
POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op);
POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp);
POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp);
POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp);
POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp);
POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp);
POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp);
POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp);
POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp);
POPULATE_BCAST(BroadcastShiftRightArithmeticOp,
xla_hlo::ShiftRightArithmeticOp);
POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp);
POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp);
POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp);
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
// Broadcasting ops requiring special construction.
PopulateForBinaryOp<BroadcastComplexOp, xla_hlo::ComplexOp,
HloComplexAdaptor>(context, patterns);
PopulateForBinaryOp<BroadcastCompareOp, xla_hlo::CompareOp,
HloCompareAdaptor>(context, patterns);
PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
context, patterns);
PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
context, patterns);
}
} // namespace xla_chlo

View File

@ -32,8 +32,8 @@ struct TestChloLegalizeToHloPass
OwningRewritePatternList conversionPatterns;
conversionTarget.addIllegalDialect<XlaHloClientDialect>();
// Consider the xla_hlo dialect legal for tests.
conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<mhlo::XlaHloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();

View File

@ -37,7 +37,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
template <typename T>
@ -128,20 +128,20 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
return success();
}
};
struct HloToLhloDynamicBroadcastInDimOpConverter
: public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
public:
using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
Value resultBuffer = InsertDynamicAllocAndDealloc(
@ -162,7 +162,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
// and size of the target dimension if size-1 dimension expansion is
// necessary.
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
@ -220,12 +220,12 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
}
};
struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
public:
using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
xla_hlo::ReduceOp op, ArrayRef<Value> operands,
mhlo::ReduceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
// TODO(b/137624192) Implement variadic reduce.
@ -314,10 +314,10 @@ class HloToLhloTensorStoreOpConverter
// "xla_lhlo.fusion"() ({
// %0 = tensor_load %arg1 : memref<2x2xf32>
// %1 = tensor_load %arg2 : memref<2x2xf32>
// %2 = "xla_hlo.add"(%0, %1) :
// %2 = "mhlo.add"(%0, %1) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// %3 = tensor_load %arg0 : memref<2x2xf32>
// %4 = "xla_hlo.multiply"(%2, %3) :
// %4 = "mhlo.multiply"(%2, %3) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// tensor_store %4, %arg3 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> ()
@ -344,8 +344,8 @@ class HloToLhloTensorStoreOpConverter
// FuncOp signature conversion example:
//
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>,
// %0 = "mhlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
// tensor<4xf32> %1 = "mhlo.add"(%arg0, %0) : (tensor<4xf32>,
// tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32>
// }
//
@ -388,7 +388,7 @@ struct HloLegalizeToLhlo
target.addIllegalOp<mlir::TensorStoreOp>();
target.addLegalOp<ModuleTerminatorOp>();
target.addLegalOp<TensorFromElementsOp>();
target.addIllegalDialect<xla_hlo::XlaHloDialect>();
target.addIllegalDialect<mhlo::XlaHloDialect>();
BufferAssignmentTypeConverter converter;
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
@ -442,38 +442,38 @@ void populateHLOToLHLOConversionPattern(
// clang-format off
patterns->insert<
HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloOpConverter<xla_hlo::AbsOp>,
HloToLhloOpConverter<xla_hlo::AddOp>,
HloToLhloOpConverter<xla_hlo::AndOp>,
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
HloToLhloOpConverter<xla_hlo::CeilOp>,
HloToLhloOpConverter<xla_hlo::CompareOp>,
HloToLhloOpConverter<xla_hlo::ComplexOp>,
HloToLhloOpConverter<xla_hlo::ConstOp>,
HloToLhloOpConverter<xla_hlo::ConvOp>,
HloToLhloOpConverter<xla_hlo::ConvertOp>,
HloToLhloOpConverter<xla_hlo::CopyOp>,
HloToLhloOpConverter<xla_hlo::CosOp>,
HloToLhloOpConverter<xla_hlo::DivOp>,
HloToLhloOpConverter<xla_hlo::DotOp>,
HloToLhloOpConverter<xla_hlo::ExpOp>,
HloToLhloOpConverter<xla_hlo::GatherOp>,
HloToLhloOpConverter<xla_hlo::ImagOp>,
HloToLhloOpConverter<xla_hlo::IotaOp>,
HloToLhloOpConverter<xla_hlo::LogOp>,
HloToLhloOpConverter<xla_hlo::MaxOp>,
HloToLhloOpConverter<xla_hlo::MinOp>,
HloToLhloOpConverter<xla_hlo::MulOp>,
HloToLhloOpConverter<xla_hlo::NegOp>,
HloToLhloOpConverter<xla_hlo::RealOp>,
HloToLhloOpConverter<xla_hlo::RemOp>,
HloToLhloOpConverter<xla_hlo::RsqrtOp>,
HloToLhloOpConverter<xla_hlo::ReshapeOp>,
HloToLhloOpConverter<xla_hlo::SelectOp>,
HloToLhloOpConverter<xla_hlo::SignOp>,
HloToLhloOpConverter<xla_hlo::SqrtOp>,
HloToLhloOpConverter<xla_hlo::SubOp>,
HloToLhloOpConverter<xla_hlo::TanhOp>,
HloToLhloOpConverter<mhlo::AbsOp>,
HloToLhloOpConverter<mhlo::AddOp>,
HloToLhloOpConverter<mhlo::AndOp>,
HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
HloToLhloOpConverter<mhlo::CeilOp>,
HloToLhloOpConverter<mhlo::CompareOp>,
HloToLhloOpConverter<mhlo::ComplexOp>,
HloToLhloOpConverter<mhlo::ConstOp>,
HloToLhloOpConverter<mhlo::ConvOp>,
HloToLhloOpConverter<mhlo::ConvertOp>,
HloToLhloOpConverter<mhlo::CopyOp>,
HloToLhloOpConverter<mhlo::CosOp>,
HloToLhloOpConverter<mhlo::DivOp>,
HloToLhloOpConverter<mhlo::DotOp>,
HloToLhloOpConverter<mhlo::ExpOp>,
HloToLhloOpConverter<mhlo::GatherOp>,
HloToLhloOpConverter<mhlo::ImagOp>,
HloToLhloOpConverter<mhlo::IotaOp>,
HloToLhloOpConverter<mhlo::LogOp>,
HloToLhloOpConverter<mhlo::MaxOp>,
HloToLhloOpConverter<mhlo::MinOp>,
HloToLhloOpConverter<mhlo::MulOp>,
HloToLhloOpConverter<mhlo::NegOp>,
HloToLhloOpConverter<mhlo::RealOp>,
HloToLhloOpConverter<mhlo::RemOp>,
HloToLhloOpConverter<mhlo::RsqrtOp>,
HloToLhloOpConverter<mhlo::ReshapeOp>,
HloToLhloOpConverter<mhlo::SelectOp>,
HloToLhloOpConverter<mhlo::SignOp>,
HloToLhloOpConverter<mhlo::SqrtOp>,
HloToLhloOpConverter<mhlo::SubOp>,
HloToLhloOpConverter<mhlo::TanhOp>,
HloToLhloReduceOpConverter,
HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter
@ -489,5 +489,5 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -35,7 +35,7 @@ limitations under the License.
using mlir::PassRegistration;
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
struct LegalizeControlFlow
: public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {
@ -51,7 +51,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
OpBuilder* builder) {
for (auto& old_block : region->getBlocks()) {
Block* block = mapper.lookup(&old_block);
auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
auto return_op = dyn_cast<mhlo::ReturnOp>(block->getTerminator());
if (!return_op) continue;
builder->setInsertionPointToEnd(block);
builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
@ -61,7 +61,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
return success();
}
LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
Operation* op_inst = if_op.getOperation();
mlir::OpBuilder builder(if_op);
auto orig_block = op_inst->getBlock();
@ -106,13 +106,13 @@ LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
return success();
}
LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
// Converts an XLA while loop into control flow. This generates a set of MLIR
// blocks and branches, along with inlining the regions provided by the XLA
// while loop. The structure should be similar to below:
//
// <prior operations>
// %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
// %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
// <post operations>
auto* op_inst = while_op.getOperation();
mlir::OpBuilder builder(while_op);
@ -147,7 +147,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// extract_element and conditional branch. This changes the block below:
// ^cond(%0):
// <inlined conditional region>
// "xla_hlo".return(%1)
// "mhlo".return(%1)
//
// Into:
// ^cond(%0):
@ -156,14 +156,14 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// cond_br %2, ^body(%0), ^tail(%0) // Branch.
builder.setInsertionPointToStart(cond_block);
// Replace the xla_hlo::ReturnOp with a branch back to the condition block.
// This is required as the xla_hlo::ReturnOp is used to mark the end of a
// Replace the mhlo::ReturnOp with a branch back to the condition block.
// This is required as the mhlo::ReturnOp is used to mark the end of a
// block for regions nested inside of a operations (MLIR ReturnOp cannot be
// nested within an non-function region).
for (auto& block : while_op.cond()) {
auto new_block = mapper.lookup(&block);
auto return_op = dyn_cast<xla_hlo::ReturnOp>(new_block->getTerminator());
auto return_op = dyn_cast<mhlo::ReturnOp>(new_block->getTerminator());
if (!return_op) continue;
builder.setInsertionPointToEnd(new_block);
@ -183,7 +183,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// conditional block. This changes the block below:
// ^body(%0):
// <inlined body block>
// "xla_hlo".return(%1)
// "mhlo".return(%1)
//
// Into:
// ^body(%0):
@ -191,8 +191,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// br ^cond(%0) // Branch.
for (auto& block : while_op.body()) {
auto new_block = mapper.lookup(&block);
auto return_op =
dyn_cast<mlir::xla_hlo::ReturnOp>(new_block->getTerminator());
auto return_op = dyn_cast<mlir::mhlo::ReturnOp>(new_block->getTerminator());
if (!return_op) continue;
builder.setInsertionPointToEnd(new_block);
builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
@ -224,14 +223,14 @@ void LegalizeControlFlow::runOnFunction() {
}
}
} // namespace
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
mlir::xla_hlo::createLegalizeControlFlowPass() {
mlir::mhlo::createLegalizeControlFlowPass() {
return std::make_unique<LegalizeControlFlow>();
}
static PassRegistration<mlir::xla_hlo::LegalizeControlFlow> legalize_cf_pass(
static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass(
"xla-legalize-control-flow",
"Legalize from XLA control flow to MLIR control flow");

View File

@ -28,14 +28,14 @@ namespace mlir {
namespace {
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"
} // end anonymous namespace
namespace xla_hlo {
namespace mhlo {
namespace {
class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
LogicalResult matchAndRewrite(mhlo::CompareOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.lhs();
auto rhs = op.rhs();
@ -68,11 +68,11 @@ class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
}
};
class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
LogicalResult matchAndRewrite(mhlo::CompareOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.lhs();
auto rhs = op.rhs();
@ -109,11 +109,11 @@ class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
// convert the integer constant to iota result type. For complex types, the real
// part is replaced with the generated constant and the imaginary part is
// replaced with zero tensor.
class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::IotaOp op,
LogicalResult matchAndRewrite(mhlo::IotaOp op,
PatternRewriter &rewriter) const override {
auto output_type = op.getType().cast<ShapedType>();
auto output_size = output_type.getNumElements();
@ -168,8 +168,7 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
auto imag_zeroes =
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
imag_zeroes);
rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iota_const, imag_zeroes);
return success();
}
};
@ -197,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
/// Perform the lowering to standard dialect.
void LegalizeToStandard::runOnFunction() {
OwningRewritePatternList patterns;
mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
mlir::mhlo::PopulateXlaToStdPatterns(&patterns, &getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
static PassRegistration<LegalizeToStandard> legalize_pass(
"xla-legalize-to-std", "Legalize from XLA dialect to standard dialect");
} // end namespace xla_hlo
} // end namespace mhlo
} // end namespace mlir

View File

@ -84,14 +84,14 @@ Value TransposeReshape(Value arg, mlir::Location loc,
transposed_shape.push_back(arg_shape[val]);
}
auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
auto transpose_result = rewriter->create<mlir::xla_hlo::TransposeOp>(
auto transpose_result = rewriter->create<mlir::mhlo::TransposeOp>(
loc, transpose_type, arg, transpose_permutation_attr);
// Return the final result.
auto reshaped_type =
RankedTensorType::get({left_size, right_size}, element_type);
return rewriter->create<mlir::xla_hlo::ReshapeOp>(loc, reshaped_type,
transpose_result);
return rewriter->create<mlir::mhlo::ReshapeOp>(loc, reshaped_type,
transpose_result);
}
Value ProcessDotArg(Value arg, mlir::Location loc,
@ -125,8 +125,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc,
return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
}
struct GeneralDotConvert
: public OpRewritePattern<mlir::xla_hlo::DotGeneralOp> {
struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
// Attempts to lower a General Dot operator to a standard Dot operator.
// General dots include batching dimensions and can have collapsing
// dimensions along any axis. Inserting correctly arrange transpose and
@ -138,7 +137,7 @@ struct GeneralDotConvert
explicit GeneralDotConvert(MLIRContext *context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op,
LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op,
PatternRewriter &rewriter) const override {
auto dot_element_type = mlir::getElementTypeOrSelf(op);
@ -162,11 +161,11 @@ struct GeneralDotConvert
auto new_dot_type =
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
auto new_dot_op = rewriter.create<mlir::xla_hlo::DotOp>(
auto new_dot_op = rewriter.create<mlir::mhlo::DotOp>(
op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config()));
rewriter.replaceOpWithNewOp<mlir::xla_hlo::ReshapeOp>(op, op.getType(),
new_dot_op);
rewriter.replaceOpWithNewOp<mlir::mhlo::ReshapeOp>(op, op.getType(),
new_dot_op);
return success();
}
};
@ -176,15 +175,14 @@ struct LegalizeGeneralDot
/// Lower all general dots that can be represented as a non-batched matmul.
void runOnFunction() override {
OwningRewritePatternList patterns;
mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
&getContext());
mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};
} // namespace
void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(
void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
OwningRewritePatternList *patterns, MLIRContext *ctx) {
patterns->insert<GeneralDotConvert>(ctx);
}

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
@ -86,5 +86,5 @@ void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
patterns->insert<ClampWithBroadcastConvert>(context);
}
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
@ -33,7 +33,7 @@ struct TestMaterializeBroadcastsPass
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns;
// Consider the xla_hlo dialect legal for tests.
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<XlaHloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass
} // namespace
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
static mlir::PassRegistration<mlir::xla_hlo::TestMaterializeBroadcastsPass>
pass("test-xla-materialize-broadcasts",
"Test pass for materializing 'broadcast_dimensions' attributes");
static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
"test-xla-materialize-broadcasts",
"Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -60,7 +60,7 @@ limitations under the License.
// shape dialect once it is ready.
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
using llvm::EquivalenceClasses;
@ -544,7 +544,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
}
FusionOp fusion =
b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs);
b.create<mhlo::FusionOp>(fused_loc, output_types, inputs);
Region& region = fusion.fused_computation();
region.push_back(new Block);
Block& block = region.front();
@ -552,7 +552,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
op->moveBefore(&block, block.end());
}
b.setInsertionPoint(&block, block.end());
b.create<xla_hlo::ReturnOp>(fused_loc, outputs);
b.create<mhlo::ReturnOp>(fused_loc, outputs);
for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
Value output = std::get<0>(output_and_result);
@ -572,8 +572,8 @@ std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
return std::make_unique<XlaHloFusion>();
}
static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass(
"xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns.");
static PassRegistration<XlaHloFusion> mhlo_fusion_pass(
"xla-hlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
@ -81,5 +81,5 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
return std::make_unique<SinkConstantsToControlFlow>();
}
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
@ -40,12 +40,12 @@ Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
if (shape_value) {
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
loc, result_type, value_1d, shape_value, dims);
}
assert(result_type.hasStaticShape());
return rewriter.create<xla_hlo::BroadcastInDimOp>(loc, result_type, value_1d,
dims);
return rewriter.create<mhlo::BroadcastInDimOp>(loc, result_type, value_1d,
dims);
}
// Calculate the shape value of operand, assuming it is a dynamic shape with
@ -89,25 +89,25 @@ Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
auto epsilon_tensor_attr =
DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
Value epsilon =
rewriter.create<xla_hlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
rewriter.create<mhlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
if (broadcast_to_type.hasStaticShape()) {
return rewriter.create<xla_hlo::BroadcastInDimOp>(
return rewriter.create<mhlo::BroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
}
Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, shape_value,
/*broadcast_dims=*/dims);
}
class UnfuseBatchNormInferencePattern
: public OpRewritePattern<xla_hlo::BatchNormInferenceOp> {
: public OpRewritePattern<mhlo::BatchNormInferenceOp> {
public:
using OpRewritePattern<xla_hlo::BatchNormInferenceOp>::OpRewritePattern;
using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op,
LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,
PatternRewriter& rewriter) const override {
// Enforce type invariants.
// Note that we deduce the actual element type from the variance,
@ -132,9 +132,9 @@ class UnfuseBatchNormInferencePattern
if (!epsilon) {
return failure();
}
Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(),
bn_op.variance(), epsilon);
stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);
Value stddev =
rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon);
stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev);
// Broadcast all terms.
Value shape_value;
@ -156,14 +156,13 @@ class UnfuseBatchNormInferencePattern
// Compute:
// scale * (input - mean) / stddev + offset
Value result = rewriter.create<xla_hlo::SubOp>(
bn_op.getLoc(), bn_op.operand(), broadcast_mean);
result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result,
broadcast_scale);
result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result,
broadcast_stddev);
rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result,
broadcast_offset);
Value result = rewriter.create<mhlo::SubOp>(bn_op.getLoc(), bn_op.operand(),
broadcast_mean);
result =
rewriter.create<mhlo::MulOp>(bn_op.getLoc(), result, broadcast_scale);
result =
rewriter.create<mhlo::DivOp>(bn_op.getLoc(), result, broadcast_stddev);
rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, result, broadcast_offset);
return success();
}
@ -180,5 +179,5 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
patterns->insert<UnfuseBatchNormInferencePattern>(context);
}
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass
} // namespace
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
static mlir::PassRegistration<mlir::xla_hlo::TestUnfuseBatchNormPass> pass(
static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass(
"test-xla-unfuse-batch-norm",
"Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -182,7 +182,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
// This code has been adapted from IREE's
// (https://github.com/google/iree/) xla_hlo -> linalg conversion.
// (https://github.com/google/iree/) mhlo -> linalg conversion.
LogicalResult matchAndRewrite(
xla_lhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
@ -348,14 +348,14 @@ class BroadcastConverter
class HloBroadcastInDimConverter
: public DataMovementOpConverter<HloBroadcastInDimConverter,
xla_hlo::BroadcastInDimOp, false> {
mhlo::BroadcastInDimOp, false> {
public:
using DataMovementOpConverter<HloBroadcastInDimConverter,
xla_hlo::BroadcastInDimOp,
mhlo::BroadcastInDimOp,
false>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(
xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) {
mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
auto resultType = getXLAOpResultType<false>(broadcastOp);
auto operandType =
broadcastOp.operand().getType().template cast<ShapedType>();
@ -845,7 +845,7 @@ struct HloLegalizeToLinalg
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
auto func = getFunction();
xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
signalPassFailure();
}
@ -863,40 +863,40 @@ static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
} // namespace xla_lhlo
namespace xla_hlo {
namespace mhlo {
void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
HloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
PointwiseToLinalgConverter<xla_hlo::ComplexOp, false>,
PointwiseToLinalgConverter<xla_hlo::ConvertOp, false>,
PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
PointwiseToLinalgConverter<xla_hlo::CosOp, false>,
PointwiseToLinalgConverter<xla_hlo::DivOp, false>,
PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
PointwiseToLinalgConverter<xla_hlo::ImagOp, false>,
PointwiseToLinalgConverter<xla_hlo::LogOp, false>,
PointwiseToLinalgConverter<xla_hlo::MaxOp, false>,
PointwiseToLinalgConverter<xla_hlo::MinOp, false>,
PointwiseToLinalgConverter<xla_hlo::MulOp, false>,
PointwiseToLinalgConverter<xla_hlo::NegOp, false>,
PointwiseToLinalgConverter<xla_hlo::RealOp, false>,
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
PointwiseToLinalgConverter<xla_hlo::SinOp, false>,
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
ReverseConverter<xla_hlo::ReverseOp, false>,
TransposeConverter<xla_hlo::TransposeOp, false>>(context);
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);
}
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
@ -905,5 +905,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass(
"hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect");
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
// TODO(frgossen): Make it variadic.
@ -69,7 +69,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
operandTy.getElementType());
Value flatOperand = rewriter.create<xla_hlo::DynamicReshapeOp>(
Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
loc, flatTensorTy, operand, flatShapeAsDimTensor);
// Generate IR for the actual operation.
@ -80,7 +80,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
rewriter.getIndexType());
Value shapeAsExtentTensor =
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
Value result = rewriter.create<xla_hlo::DynamicReshapeOp>(
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
loc, operandTy, flatResult, shapeAsExtentTensor);
rewriter.replaceOp(op, result);
@ -184,5 +184,5 @@ static PassRegistration<TransformUnrankedHloPass> transform_unranked_hlo_pass(
"transform-unranked-hlo",
"Realize element-wise operations on ranked tensors where possible");
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -2,107 +2,107 @@
// CHECK-LABEL: add_fold
func @add_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%1 = xla_hlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[6, 8, 10, 12]>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%1 = mhlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[6, 8, 10, 12]>
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: add_scalar_fold
func @add_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<1> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<6>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<1> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<6>
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: add_fold_float
func @add_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
%2 = "mhlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: sub_scalar_fold
func @sub_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<5> : tensor<4xi64>
%1 = xla_hlo.constant dense<1> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<4>
%2 = "xla_hlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<5> : tensor<4xi64>
%1 = mhlo.constant dense<1> : tensor<4xi64>
// CHECK: mhlo.constant dense<4>
%2 = "mhlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: multiply_scalar_fold
func @multiply_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<5> : tensor<4xi64>
%1 = xla_hlo.constant dense<3> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<15>
%2 = "xla_hlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<5> : tensor<4xi64>
%1 = mhlo.constant dense<3> : tensor<4xi64>
// CHECK: mhlo.constant dense<15>
%2 = "mhlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: divide_scalar_fold
func @divide_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<1>
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<1>
%2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: divide_fold_float
func @divide_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
%2 = "mhlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<7>
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<7>
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: max_fold_float
func @max_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: min_scalar_fold
func @min_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<-5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<-5>
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<-5> : tensor<4xi64>
// CHECK: mhlo.constant dense<-5>
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: min_fold_float
func @min_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: concatenate_noop
func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
%0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
// CHECK: return [[ARG]]
return %0 : tensor<4xi32>
@ -112,7 +112,7 @@ func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
// CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
// CHECK: return [[ARG0]]
return %0 : tensor<4xi32>
@ -120,34 +120,34 @@ func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) ->
// CHECK-LABEL: concatenate_empty_bool
func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
return %0 : tensor<0xi1>
}
// CHECK-LABEL: concatenate_empty_int
func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
return %0 : tensor<0xi32>
}
// CHECK-LABEL: concatenate_empty_float
func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
return %0 : tensor<0xf32>
}
// CHECK-LABEL: concatenate_const_1D
func @concatenate_const_1D() -> tensor<4xi32> {
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[0, 1, 2, 3]>
%0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32>
%1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[0, 1, 2, 3]>
%0 = mhlo.constant dense<[0, 1]> : tensor<2xi32>
%1 = mhlo.constant dense<[2, 3]> : tensor<2xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
// CHECK: return [[VAL]]
return %2 : tensor<4xi32>
@ -155,11 +155,11 @@ func @concatenate_const_1D() -> tensor<4xi32> {
// CHECK-LABEL: concatenate_const_1D_float
func @concatenate_const_1D_float() -> tensor<4xf32> {
// CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
// CHECK: [[VAL:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
%0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
%1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
%0 = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
%1 = mhlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
// CHECK: return [[VAL]]
return %2 : tensor<4xf32>
@ -167,12 +167,12 @@ func @concatenate_const_1D_float() -> tensor<4xf32> {
// CHECK-LABEL: concatenate_const_2D_vertical
func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: ]>
%0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
%1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
%0 = mhlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
%1 = mhlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
@ -180,12 +180,12 @@ func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
// CHECK-LABEL: concatenate_const_2D_horizontal
func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
// CHECK-SAME: [0, 2], [1, 3]
// CHECK-SAME: ]>
%0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
%1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
%0 = mhlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
%1 = mhlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
@ -193,40 +193,40 @@ func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
// CHECK-LABEL: dynamic_slice_variable_start
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// CHECK: "xla_hlo.dynamic-slice"
%1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
// CHECK: "mhlo.dynamic-slice"
%1 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %1 : tensor<1x4xi32>
}
// CHECK-LABEL: dynamic_slice_constant_start
func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
// CHECK-DAG-SAME: limit_indices = dense<3> : tensor<1xi64>
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>}
// CHECK: return %[[RESULT]] : tensor<2xi32>
%0 = xla_hlo.constant dense<1> : tensor<i64>
%1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
%0 = mhlo.constant dense<1> : tensor<i64>
%1 = "mhlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
return %1 : tensor<2xi32>
}
// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape
func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<?x4xi32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
// CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64>
// CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64>
// CHECK: return %[[RESULT]] : tensor<?x4xi32>
%0 = xla_hlo.constant dense<1> : tensor<i64>
%1 = xla_hlo.constant dense<0> : tensor<i64>
%2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
%0 = mhlo.constant dense<1> : tensor<i64>
%1 = mhlo.constant dense<0> : tensor<i64>
%2 = "mhlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
return %2 : tensor<?x4xi32>
}
// CHECK-LABEL: slice_2D_noop
// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64>
func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
%0 = "xla_hlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)
%0 = "mhlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x2xi64>
@ -234,80 +234,80 @@ func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
// CHECK-LABEL: slice_1D_fold
func @slice_1D_fold() -> tensor<2xi64> {
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[7, 9]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[7, 9]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}
// CHECK-LABEL: slice_1D_fp
func @slice_1D_fp() -> tensor<2xf32> {
%0 = xla_hlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
// CHECK: xla_hlo.constant dense<[7.000000e+00, 9.000000e+00]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
%0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
// CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
return %1 : tensor<2xf32>
}
// CHECK-LABEL: slice_1D_strided_fold
func @slice_1D_strided_fold() -> tensor<2xi64> {
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[7, 10]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[7, 10]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}
// CHECK-LABEL: slice_2D_fold
func @slice_2D_fold() -> tensor<2x2xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [6, 7],
// CHECK-SAME: [10, 11]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
return %1 : tensor<2x2xi64>
}
// CHECK-LABEL: slice_2D_fold_horizontal
func @slice_2D_fold_horizontal() -> tensor<1x4xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [0, 1, 2, 3]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
return %1 : tensor<1x4xi64>
}
// CHECK-LABEL: slice_2D_fold_vertical
func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [2], [6], [10], [14]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
return %1 : tensor<4x1xi64>
}
// CHECK-LABEL: slice_concat_fold_first
func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg0
return %1 : tensor<1x5xf32>
}
// CHECK-LABEL: slice_concat_fold_second
func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg1
return %1 : tensor<1x5xf32>
}
// CHECK-LABEL: slice_concat_fold_second_with_slice
func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<1x4xf32>
@ -315,9 +315,9 @@ func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<
// CHECK-LABEL: slice_concat_fold_middle
func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<1x5xf32>
@ -325,11 +325,11 @@ func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %
// CHECK-LABEL: slice_concat_fold_two
func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> {
// CHECK: [[CONCAT:%.+]] = "xla_hlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
%0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)
// CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<2x5xf32>
@ -338,72 +338,72 @@ func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg
// CHECK-LABEL: func @broadcast_in_dim_identity
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
// CHECK: return %arg0
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts
func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
// CHECK: xla_hlo.broadcast_in_dim
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: mhlo.broadcast_in_dim
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation
func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: xla_hlo.broadcast_in_dim
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: mhlo.broadcast_in_dim
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
// CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
%0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
// CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
// CHECK: return %[[RESULT]] : tensor<5x4xf32>
return %0 : tensor<5x4xf32>
}
// CHECK-LABEL: @complex_expand_fold
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
%1 = "xla_hlo.real"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "xla_hlo.imag"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
%1 = "mhlo.real"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "mhlo.imag"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
// CHECK: return %arg0, %arg1
return %1, %2 : tensor<4xf32>, tensor<4xf32>
}
// CHECK-LABEL: @complex_collapse_fold
func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>> {
%0 = "xla_hlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%1 = "xla_hlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "xla_hlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = "mhlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%1 = "mhlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "mhlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK: return %arg0
return %2 : tensor<4xcomplex<f32>>
}
// CHECK-LABEL: @dynamic_iota_is_static
func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
// CHECK: return [[RESULT]]
%0 = "xla_hlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-LABEL: @iota_not_lowered_to_constant
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
// CHECK: return [[RESULT]]
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-LABEL: @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: "mhlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
%0 = "mhlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -411,30 +411,30 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: return [[ARG]]
%0 = "xla_hlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
%0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
// CHECK: xla_hlo.reshape
%0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
// CHECK: mhlo.reshape
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
return %0 : tensor<4x1xf32>
}
// CHECK-LABEL: do_not_dce_while_with_outfeed
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: xla_hlo.while
%0 = "xla_hlo.while"(%arg0) ( {
// CHECK: mhlo.while
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.create_token"() : () -> !xla_hlo.token
%1 = "mhlo.create_token"() : () -> !mhlo.token
// Side-effecting op outfeed present inside while.
%2 = "xla_hlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor<i64>, !xla_hlo.token) -> !xla_hlo.token
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
%2 = "mhlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor<i64>, !mhlo.token) -> !mhlo.token
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %arg0 : tensor<i64>
@ -442,15 +442,15 @@ func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-LABEL: dce_while_without_side_effect
func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NOT: xla_hlo.while
%0 = "xla_hlo.while"(%arg0) ( {
// CHECK-NOT: mhlo.while
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.create_token"() : () -> !xla_hlo.token
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
%1 = "mhlo.create_token"() : () -> !mhlo.token
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %arg0 : tensor<i64>

View File

@ -4,7 +4,7 @@
// representative op for detailed broadcast semantics.
// CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.add %arg0, %arg1
// CHECK: mhlo.add %arg0, %arg1
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -20,9 +20,9 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]]
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
@ -41,9 +41,9 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
@ -62,9 +62,9 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
@ -76,7 +76,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// Verifies that broadcast_dimensions validity checks are valid.
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@ -85,7 +85,7 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<
// Verifies that broadcast_dimensions validity checks are valid.
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@ -113,7 +113,7 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1:
// expansions. Tests below merely verify that the op has an expansion.
// CHECK-LABEL: @andWithoutBroadcast
func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.and %arg0, %arg1
// CHECK: mhlo.and %arg0, %arg1
%0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -121,7 +121,7 @@ func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
// -----
// CHECK-LABEL: @atan2WithoutBroadcast
func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.atan2 %arg0, %arg1
// CHECK: mhlo.atan2 %arg0, %arg1
%0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -129,7 +129,7 @@ func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// -----
// CHECK-LABEL: @compareWithoutBroadcast
func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> {
// CHECK: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -137,7 +137,7 @@ func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @complexWithoutBroadcast
func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> {
// CHECK: "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
@ -145,7 +145,7 @@ func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @divideWithoutBroadcast
func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.divide %arg0, %arg1
// CHECK: mhlo.divide %arg0, %arg1
%0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -153,7 +153,7 @@ func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tens
// -----
// CHECK-LABEL: @maximumWithoutBroadcast
func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.maximum %arg0, %arg1
// CHECK: mhlo.maximum %arg0, %arg1
%0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -161,7 +161,7 @@ func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @minimumWithoutBroadcast
func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.minimum %arg0, %arg1
// CHECK: mhlo.minimum %arg0, %arg1
%0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -169,7 +169,7 @@ func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @multiplyWithoutBroadcast
func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.multiply %arg0, %arg1
// CHECK: mhlo.multiply %arg0, %arg1
%0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -177,7 +177,7 @@ func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te
// -----
// CHECK-LABEL: @orWithoutBroadcast
func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.or %arg0, %arg1
// CHECK: mhlo.or %arg0, %arg1
%0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -185,7 +185,7 @@ func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi
// -----
// CHECK-LABEL: @powerWithoutBroadcast
func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.power %arg0, %arg1
// CHECK: mhlo.power %arg0, %arg1
%0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -193,7 +193,7 @@ func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// -----
// CHECK-LABEL: @remainderWithoutBroadcast
func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.remainder %arg0, %arg1
// CHECK: mhlo.remainder %arg0, %arg1
%0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -201,7 +201,7 @@ func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> t
// -----
// CHECK-LABEL: @shift_leftWithoutBroadcast
func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_left %arg0, %arg1
// CHECK: mhlo.shift_left %arg0, %arg1
%0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -209,7 +209,7 @@ func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) ->
// -----
// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast
func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1
// CHECK: mhlo.shift_right_arithmetic %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -217,7 +217,7 @@ func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor
// -----
// CHECK-LABEL: @shift_right_logicalWithoutBroadcast
func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_logical %arg0, %arg1
// CHECK: mhlo.shift_right_logical %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -225,7 +225,7 @@ func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4x
// -----
// CHECK-LABEL: @subWithoutBroadcast
func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.subtract %arg0, %arg1
// CHECK: mhlo.subtract %arg0, %arg1
%0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -233,7 +233,7 @@ func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<
// -----
// CHECK-LABEL: @xorWithoutBroadcast
func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.xor %arg0, %arg1
// CHECK: mhlo.xor %arg0, %arg1
%0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}

View File

@ -3,7 +3,7 @@
// CHECK-LABEL: func @single_operand
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
%0 = "mhlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<1x2xf32>
}

View File

@ -5,7 +5,7 @@
// CHECK-LABEL: func @same_type
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @same_type(%arg: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<f32>
}
@ -15,8 +15,8 @@ func @same_type(%arg: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: func @int_widening
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i64>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i64>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i64>
}
@ -26,8 +26,8 @@ func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
// CHECK-LABEL: func @int_narrowing
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i16>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<i16>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i16>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i16>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i16>
}
@ -37,8 +37,8 @@ func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
// CHECK-LABEL: func @float_int
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @float_int(%arg: tensor<f32>) -> tensor<i32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<f32>) -> tensor<i32>
%0 = "xla_hlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<f32>) -> tensor<i32>
%0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i32>
}
@ -48,8 +48,8 @@ func @float_int(%arg: tensor<f32>) -> tensor<i32> {
// CHECK-LABEL: func @int_float
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_float(%arg: tensor<i32>) -> tensor<f32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<f32>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<f32>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<f32>
}
@ -59,8 +59,8 @@ func @int_float(%arg: tensor<i32>) -> tensor<f32> {
// CHECK-LABEL: func @high_rank_tensor
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
%0 = "xla_hlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
%0 = "mhlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<2x3xf32>
}
@ -70,9 +70,9 @@ func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {
// CHECK-LABEL: func @const_same_type
func @const_same_type() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
@ -81,9 +81,9 @@ func @const_same_type() -> tensor<i32> {
// CHECK-LABEL: func @const_float_int
func @const_float_int() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42.0> : tensor<f32>
%0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42.0> : tensor<f32>
%0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
@ -92,9 +92,9 @@ func @const_float_int() -> tensor<i32> {
// CHECK-LABEL: func @const_int_float
func @const_int_float() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor<f32>
%cst = xla_hlo.constant dense<4> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<f32>
%cst = mhlo.constant dense<4> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
@ -103,9 +103,9 @@ func @const_int_float() -> tensor<f32> {
// CHECK-LABEL: func @const_negative_int_float
func @const_negative_int_float() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
%cst = xla_hlo.constant dense<-4> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
%cst = mhlo.constant dense<-4> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
@ -114,9 +114,9 @@ func @const_negative_int_float() -> tensor<f32> {
// CHECK-LABEL: func @const_int_bf16
func @const_int_bf16() -> tensor<bf16> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
%cst = xla_hlo.constant dense<4> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
%cst = mhlo.constant dense<4> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<bf16>
}
@ -125,9 +125,9 @@ func @const_int_bf16() -> tensor<bf16> {
// CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i16> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i16>
%cst = xla_hlo.constant dense<42.0> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
%cst = mhlo.constant dense<42.0> : tensor<bf16>
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i16>
}
@ -136,9 +136,9 @@ func @const_bf16_int() -> tensor<i16> {
// CHECK-LABEL: func @const_int_narrowing
func @const_int_narrowing() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42> : tensor<i64>
%0 = "xla_hlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<i64>
%0 = "mhlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
@ -147,9 +147,9 @@ func @const_int_narrowing() -> tensor<i32> {
// CHECK-LABEL: func @const_int_widening
func @const_int_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i64>
%cst = xla_hlo.constant dense<42> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
@ -158,9 +158,9 @@ func @const_int_widening() -> tensor<i64> {
// CHECK-LABEL: func @const_negative_int_widening
func @const_negative_int_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-42> : tensor<i64>
%cst = xla_hlo.constant dense<-42> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
%cst = mhlo.constant dense<-42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
@ -169,9 +169,9 @@ func @const_negative_int_widening() -> tensor<i64> {
// CHECK-LABEL: func @const_float_narrowing
func @const_float_narrowing() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
%cst = xla_hlo.constant dense<4.2> : tensor<f64>
%0 = "xla_hlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
%cst = mhlo.constant dense<4.2> : tensor<f64>
%0 = "mhlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
@ -180,9 +180,9 @@ func @const_float_narrowing() -> tensor<f32> {
// CHECK-LABEL: func @const_f32_bf16
func @const_f32_bf16() -> tensor<bf16> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
%cst = xla_hlo.constant dense<42.0> : tensor<f32>
%0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
%cst = mhlo.constant dense<42.0> : tensor<f32>
%0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<bf16>
}
@ -191,9 +191,9 @@ func @const_f32_bf16() -> tensor<bf16> {
// CHECK-LABEL: func @const_bf16_f64
func @const_bf16_f64() -> tensor<f64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor<f64>
%cst = xla_hlo.constant dense<4.2> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.187500e+00> : tensor<f64>
%cst = mhlo.constant dense<4.2> : tensor<bf16>
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f64>
}
@ -202,9 +202,9 @@ func @const_bf16_f64() -> tensor<f64> {
// CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i64>
%cst = xla_hlo.constant dense<42.0> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
%cst = mhlo.constant dense<42.0> : tensor<bf16>
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
@ -214,11 +214,11 @@ func @const_bf16_int() -> tensor<i64> {
// CHECK-LABEL: func @const_high_rank_tensor
func @const_high_rank_tensor() -> tensor<2x3xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
// CHECK-SAME: ]> : tensor<2x3xi32>
%cst = xla_hlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
%0 = "xla_hlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
%cst = mhlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
%0 = "mhlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2x3xi32>
}

View File

@ -4,7 +4,7 @@
// BOTH-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
%tensor_result = "mhlo.exponential"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
@ -28,11 +28,11 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// BOTH-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
%2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
%3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
%4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32>
%5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
%1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
%2 = mhlo.add %arg0, %1 : tensor<4xf32>
%3 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
%4 = mhlo.subtract %arg1, %3 : tensor<4xf32>
%5 = mhlo.multiply %2, %4 : tensor<4xf32>
return %5 : tensor<4xf32>
}
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
@ -65,12 +65,12 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
%sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
%tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
@ -86,7 +86,7 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
// BOTH-LABEL: func @copy
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.copy"(%tensor_operand)
%tensor_result = "mhlo.copy"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -98,7 +98,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @exp
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
%tensor_result = "mhlo.exponential"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -110,7 +110,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @log
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.log"(%tensor_operand)
%tensor_result = "mhlo.log"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -125,7 +125,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%tensor_pred = tensor_load %pred : memref<2x2xi1>
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
%tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -138,7 +138,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs)
%tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
{comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
@ -151,7 +151,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
// BOTH-LABEL: func @broadcast
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
%tensor_operand = tensor_load %operand : memref<5xf32>
%tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand)
%tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
{broadcast_dimensions = dense<1> : tensor<1xi64>}
: (tensor<5xf32>) -> tensor<10x5xf32>
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
@ -170,7 +170,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%shape = call @external_func() : () -> tensor<3xi64>
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// BOTH: %[[SHAPE:.*]] = call @external_func()
@ -226,7 +226,7 @@ func @complex(%real: memref<2x2xf32>,
%result: memref<2x2xcomplex<f32>>) {
%tensor_real = tensor_load %real : memref<2x2xf32>
%tensor_imag = tensor_load %imag : memref<2x2xf32>
%tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag)
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
@ -238,7 +238,7 @@ func @complex(%real: memref<2x2xf32>,
// BOTH-LABEL: func @real
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "xla_hlo.real"(%tensor_operand)
%tensor_result = "mhlo.real"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -250,7 +250,7 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @imag
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "xla_hlo.imag"(%tensor_operand)
%tensor_result = "mhlo.imag"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -261,7 +261,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @iota
func @iota(%result: memref<10xi32>) {
%tensor_result = "xla_hlo.iota"()
%tensor_result = "mhlo.iota"()
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
tensor_store %tensor_result, %result : memref<10xi32>
@ -273,7 +273,7 @@ func @iota(%result: memref<10xi32>) {
// BOTH-LABEL: func @abs
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.abs"(%tensor_operand)
%tensor_result = "mhlo.abs"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -285,7 +285,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.ceil"(%tensor_operand)
%tensor_result = "mhlo.ceil"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -297,7 +297,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @convert
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.convert"(%tensor_operand)
%tensor_result = "mhlo.convert"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// BOTH-NOT: tensor_store
@ -310,7 +310,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @cos
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.cosine"(%tensor_operand)
%tensor_result = "mhlo.cosine"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -322,7 +322,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @neg
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.negate"(%tensor_operand)
%tensor_result = "mhlo.negate"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -334,7 +334,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.rsqrt"(%tensor_operand)
%tensor_result = "mhlo.rsqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -346,7 +346,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @sign
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sign"(%tensor_operand)
%tensor_result = "mhlo.sign"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -358,7 +358,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @sqrt
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sqrt"(%tensor_operand)
%tensor_result = "mhlo.sqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -370,7 +370,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.tanh"(%tensor_operand)
%tensor_result = "mhlo.tanh"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -383,7 +383,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs)
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -395,7 +395,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// Dynamic shape binary element-wise operation.
// BOTH-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
%result = "xla_hlo.add"(%lhs, %rhs)
%result = "mhlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
@ -420,7 +420,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// Dynamic shape unary element-wise operation.
// BOTH-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
%result = "xla_hlo.tanh"(%arg0)
%result = "mhlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
@ -448,7 +448,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "xla_hlo.dot"(%arg0, %arg0)
%dot = "mhlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
// ESC: return %[[ALLOC]]
@ -466,7 +466,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
// BOTH-SAME: window_strides = dense<[2, 1]>
%out = "xla_hlo.convolution"(%filter, %input) {
%out = "mhlo.convolution"(%filter, %input) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,

View File

@ -10,7 +10,7 @@ func @float_add(%lhs: tensor<2x2xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = addf %[[ARG0]], %[[ARG1]]
// CHECK: linalg.yield %[[RESULT]]
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -22,7 +22,7 @@ func @integer_add(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: addi
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -34,7 +34,7 @@ func @float_mul(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: mulf
%0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
%0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -46,7 +46,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: muli
%0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
%0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -58,7 +58,7 @@ func @float_remainder(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: remf
%0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
%0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -70,7 +70,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: remi_signed
%0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
%0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -79,7 +79,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
// CHECK-LABEL: func @float_rsqrt
func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
%tensor_result = "xla_hlo.rsqrt"(%operand)
%tensor_result = "mhlo.rsqrt"(%operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: linalg.generic
// CHECK: rsqrt
@ -93,7 +93,7 @@ func @float_sub(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: subf
%0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -105,7 +105,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: subi
%0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -116,7 +116,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: absf
%0 = "xla_hlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -126,7 +126,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: exp
%0 = "xla_hlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -136,7 +136,7 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: log
%0 = "xla_hlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -146,7 +146,7 @@ func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: ceilf
%0 = "xla_hlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -156,7 +156,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: negf
%0 = "xla_hlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -166,7 +166,7 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: tanh
%0 = "xla_hlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -177,7 +177,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: and
%0 = "xla_hlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
%0 = "mhlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -187,7 +187,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1>
}
@ -201,7 +201,7 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
: (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
return %0 : tensor<2x2xi1>
}
@ -216,7 +216,7 @@ func @int_cmp(%lhs: tensor<2x2xi32>,
func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: cos
%0 = "xla_hlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -226,7 +226,7 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: sin
%0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -235,7 +235,7 @@ func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @copy
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
%0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
%0 = "mhlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
return %0 : tensor<2x4x8xf32>
}
// CHECK: return [[ARG]] : tensor<2x4x8xf32>
@ -245,7 +245,7 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
// CHECK-LABEL: func @select
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = "xla_hlo.select"(%pred, %lhs, %rhs)
%0 = "mhlo.select"(%pred, %lhs, %rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
return %0 : tensor<2x2xf32>
}
@ -260,7 +260,7 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
%0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
%0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
return %0: tensor<4x2x1xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
@ -273,7 +273,7 @@ func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-LABEL: func @broadcast
func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
%0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
%0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
return %0: tensor<4x2x1x4x?x16xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
@ -286,7 +286,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @broadcast_in_dim
func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand)
%0 = "mhlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>}
: (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32>
return %0 : tensor<7x10x6x4x5xf32>
@ -302,7 +302,7 @@ func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one
func @broadcast_in_dim_with_one_to_one(
%operand: tensor<1xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand)
%0 = "mhlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[0]> : tensor<1xi64>}
: (tensor<1xf32>) -> tensor<1x5xf32>
return %0 : tensor<1x5xf32>
@ -317,7 +317,7 @@ func @broadcast_in_dim_with_one_to_one(
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand)
%0 = "mhlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[]> : tensor<0xi64>}
: (tensor<f32>) -> tensor<7x10x6xf32>
return %0 : tensor<7x10x6xf32>
@ -332,7 +332,7 @@ func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @transpose
func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
%0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
return %0 : tensor<3x2x5x9xi32>
}
@ -344,7 +344,7 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
%0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -355,7 +355,7 @@ func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
%0 = "mhlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -366,7 +366,7 @@ func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
%0 = "mhlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
return %0 : tensor<12x1x42x1xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -375,7 +375,7 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
// CHECK-LABEL: func @minf
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = "xla_hlo.minimum"(%lhs, %rhs)
%0 = "mhlo.minimum"(%lhs, %rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -389,7 +389,7 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @maxi
func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%0 = "xla_hlo.maximum"(%lhs, %rhs)
%0 = "mhlo.maximum"(%lhs, %rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -404,7 +404,7 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @add_scalar
func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK: linalg.generic
@ -417,7 +417,7 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
func @reshape_collapse_single_dim
(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
return %0 : tensor<1x784xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
@ -428,7 +428,7 @@ func @reshape_collapse_single_dim
// -----
func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
return %0 : tensor<2x4x3xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
@ -440,7 +440,7 @@ func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
// -----
func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
return %0 : tensor<2x4x2xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)>
@ -451,7 +451,7 @@ func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
// -----
func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
return %0 : tensor<1x4x2xf32>
}
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@ -462,7 +462,7 @@ func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
func @reshape_multiple_collapse
(%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
return %0 : tensor<1x4x5x6xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
@ -476,7 +476,7 @@ func @reshape_multiple_collapse
// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
@ -488,7 +488,7 @@ func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
%result = "mhlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK: linalg.generic
@ -500,7 +500,7 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
// CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
return %result : tensor<2x2xi16>
}
// CHECK: linalg.generic
@ -512,7 +512,7 @@ func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
// CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
return %result : tensor<2x2xf64>
}
// CHECK: linalg.generic
@ -524,7 +524,7 @@ func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
// CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
%result = "mhlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
@ -536,7 +536,7 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK: linalg.generic
@ -550,7 +550,7 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse
func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
%result = "xla_hlo.reverse"(%input) {
%result = "mhlo.reverse"(%input) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %result : tensor<2x3xf32>

View File

@ -1,28 +1,28 @@
// RUN: mlir-hlo-opt %s -inline | FileCheck %s
// Test case: Basic test of inlining into xla_hlo.while.
// Test case: Basic test of inlining into mhlo.while.
// CHECK-LABEL: func @caller
// CHECK: "xla_hlo.while"{{.*}}( {
// CHECK: "mhlo.while"{{.*}}( {
// CHECK: }, {
// CHECK: "xla_hlo.exponential"
// CHECK: "mhlo.exponential"
// CHECK: })
// CHECK-LABEL: func @callee
func @caller(%arg0: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
%0 = "xla_hlo.while"(%arg0) ( {
%0 = "mhlo.while"(%arg0) ( {
^entry(%unused: tensor<f32>):
"xla_hlo.return"(%pred) : (tensor<i1>) -> ()
"mhlo.return"(%pred) : (tensor<i1>) -> ()
}, {
^entry(%0: tensor<f32>):
%1 = call @callee(%0) : (tensor<f32>) -> (tensor<f32>)
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
"mhlo.return"(%1) : (tensor<f32>) -> ()
} ) : (tensor<f32>) -> (tensor<f32>)
return %0 : tensor<f32>
}
func @callee(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

View File

@ -4,21 +4,21 @@
func @while(%arg0: tensor<i64>) -> tensor<i64> {
//CHECK: br ^bb1(%arg0 : tensor<i64>)
//CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
//CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]])
//CHECK: [[VAL1:%.+]] = "mhlo.compare"([[VAL0]], [[VAL0]])
//CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor<i1>
//CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor<i64>), ^bb3([[VAL0]] : tensor<i64>)
//CHECK: ^bb2([[VAL3:%.+]]: tensor<i64>):
//CHECK: [[VAL4:%.+]] = xla_hlo.add [[VAL3]], [[VAL3]]
//CHECK: [[VAL4:%.+]] = mhlo.add [[VAL3]], [[VAL3]]
//CHECK: br ^bb1([[VAL4]] : tensor<i64>)
//CHECK: ^bb3([[VAL5:%.+]]: tensor<i64>):
%0 = "xla_hlo.while"(%arg0) ( {
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
"xla_hlo.return"(%1) : (tensor<i64>) -> ()
%1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
"mhlo.return"(%1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
// CHECK-NEXT: return [[VAL5]]
@ -30,27 +30,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor<f32>
%cst = constant dense<1.000000e+01> : tensor<f32>
// CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
// CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
%1 = "xla_hlo.if"(%0, %arg0, %arg0) ( {
%1 = "mhlo.if"(%0, %arg0, %arg0) ( {
^bb0(%arg1: tensor<f32>):
// CHECK: ^bb1([[VAL2:%.+]]: tensor<f32>):
// CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
// CHECK: [[VAL3:%.+]] = "mhlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^bb3([[VAL3]] : tensor<f32>)
%2 = "xla_hlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = "mhlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}, {
^bb0(%arg1: tensor<f32>):
// CHECK: ^bb2([[VAL4:%.+]]: tensor<f32>):
// CHECK: [[VAL5:%.+]] = "xla_hlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
// CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^bb3([[VAL5]] : tensor<f32>)
%2 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: ^bb3([[VAL6:%.+]]: tensor<f32>):
@ -62,27 +62,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
// CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
// CHECK: %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %2 = extract_element %1[] : tensor<i1>
// CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
// CHECK: ^[[BODY_ENTRY]](%3: tensor<i64>):
// CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor<i64>)
// CHECK: ^[[BODY_SUCC]](%4: tensor<i64>):
// CHECK: %5 = xla_hlo.add %4, %4 : tensor<i64>
// CHECK: %5 = mhlo.add %4, %4 : tensor<i64>
// CHECK: br ^[[COND_ENTRY]](%5 : tensor<i64>)
// CHECK: ^[[EXIT]](%6: tensor<i64>):
// CHECK: return %6 : tensor<i64>
// CHECK: }
%0 = "xla_hlo.while"(%arg0) ( {
%0 = "mhlo.while"(%arg0) ( {
^cond_entry(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^body_entry(%arg1: tensor<i64>):
br ^body_succ(%arg1: tensor<i64>)
^body_succ(%0: tensor<i64>):
%1 = xla_hlo.add %0, %0 : tensor<i64>
"xla_hlo.return"(%1) : (tensor<i64>) -> ()
%1 = mhlo.add %0, %0 : tensor<i64>
"mhlo.return"(%1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
@ -94,7 +94,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
// CHECK: br ^[[COND_SUCC:.+]](%0 : tensor<i64>)
// CHECK: ^[[COND_SUCC]](%1: tensor<i64>):
// CHECK: %2 = "xla_hlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %3 = extract_element %2[] : tensor<i1>
// CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
// CHECK: ^[[BODY_ENTRY]](%4: tensor<i64>):
@ -102,15 +102,15 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: ^[[EXIT]](%5: tensor<i64>):
// CHECK: return %5 : tensor<i64>
// CHECK: }
%0 = "xla_hlo.while"(%arg0) ( {
%0 = "mhlo.while"(%arg0) ( {
^cond_entry(%arg1: tensor<i64>):
br ^cond_succ(%arg1: tensor<i64>)
^cond_succ(%0: tensor<i64>):
%1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^body_entry(%arg1: tensor<i64>):
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
@ -123,24 +123,24 @@ func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %
// CHECK: ^[[THEN_ENTRY]](%1: tensor<f32>):
// CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor<f32>)
// CHECK: ^[[THEN_SUCC]](%2: tensor<f32>):
// CHECK: %3 = "xla_hlo.log"(%2) : (tensor<f32>) -> tensor<f32>
// CHECK: %3 = "mhlo.log"(%2) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^[[EXIT:.+]](%3 : tensor<f32>)
// CHECK: ^[[ELSE_ENTRY]](%4: tensor<f32>):
// CHECK: %5 = "xla_hlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
// CHECK: %5 = "mhlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^[[EXIT]](%5 : tensor<f32>)
// CHECK: ^[[EXIT]](%6: tensor<f32>):
// CHECK: return %6 : tensor<f32>
// CHECK: }
%1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( {
%1 = "mhlo.if"(%pred, %arg0, %arg1) ( {
^then_entry(%arg2: tensor<f32>):
br ^then_succ(%arg2: tensor<f32>)
^then_succ(%0: tensor<f32>):
%2 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = "mhlo.log"(%0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}, {
^else_entry(%arg2: tensor<f32>):
%2 = "xla_hlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = "mhlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}

View File

@ -3,19 +3,19 @@
// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %0 = addf %arg0, %arg1 : tensor<4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32>
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32>
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32>
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %4 = remf %3, %arg1 : tensor<4xf32>
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %4 : tensor<4xf32>
return %4 : tensor<4xf32>
@ -24,19 +24,19 @@ func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf
// CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: %0 = addi %arg0, %arg1 : tensor<4xi32>
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32>
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32>
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32>
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32>
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: return %4 : tensor<4xi32>
return %4 : tensor<4xi32>
@ -45,17 +45,17 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32
// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32>
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32>
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32>
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32>
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
@ -63,28 +63,28 @@ func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi
// CHECK-LABEL: func @compare_float
func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32>
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32>
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32>
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32>
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32>
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
// CHECK-LABEL: func @int_constant
func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<i32>
%0 = "xla_hlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
%0 = "mhlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32>
%1 = "xla_hlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
%1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32>
%2 = "xla_hlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
%2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
return %0, %1, %2: tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
}
@ -92,11 +92,11 @@ func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-LABEL: func @float_constant
func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<f32>
%0 = "xla_hlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
%0 = "mhlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32>
%1 = "xla_hlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
%1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32>
%2 = "xla_hlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
%2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
}
@ -105,7 +105,7 @@ func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
func @iota.const.1() -> tensor<4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -113,7 +113,7 @@ func @iota.const.1() -> tensor<4xi32> {
// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
func @iota.const.2() -> tensor<2x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
@ -121,7 +121,7 @@ func @iota.const.2() -> tensor<2x4xi32> {
// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
func @iota.const.3() -> tensor<2x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
@ -129,7 +129,7 @@ func @iota.const.3() -> tensor<2x4xi32> {
// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
func @iota.const.4() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -137,7 +137,7 @@ func @iota.const.4() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
func @iota.const.5() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -145,7 +145,7 @@ func @iota.const.5() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
func @iota.const.6() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -153,7 +153,7 @@ func @iota.const.6() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.f32
func @iota.const.f32() -> tensor<4xf32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
// CHECK-NEXT: return %[[CST]] : tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -161,7 +161,7 @@ func @iota.const.f32() -> tensor<4xf32> {
// CHECK-LABEL: func @iota.const.f64
func @iota.const.f64() -> tensor<4xf64> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
// CHECK-NEXT: return %[[CST]] : tensor<4xf64>
return %0 : tensor<4xf64>
}
@ -169,7 +169,7 @@ func @iota.const.f64() -> tensor<4xf64> {
// CHECK-LABEL: func @iota.const.bf16
func @iota.const.bf16() -> tensor<4xbf16> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
// CHECK-NEXT: return %[[CST]] : tensor<4xbf16>
return %0 : tensor<4xbf16>
}
@ -178,8 +178,8 @@ func @iota.const.bf16() -> tensor<4xbf16> {
func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
// CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
@ -188,8 +188,8 @@ func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
func @iota.const.complex.f64() -> tensor<4xcomplex<f64>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
// CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f64>>
return %0 : tensor<4xcomplex<f64>>
}

View File

@ -396,9 +396,9 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
"xla_lhlo.fusion"() ( {
%0 = tensor_load %input1 : memref<10xf32>
%1 = tensor_load %input2 : memref<10xf32>
%2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%3 = tensor_load %input3 : memref<10xf32>
%4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
tensor_store %4, %out : memref<10xf32>
"xla_lhlo.terminator"() : () -> ()
} ) : () -> ()
@ -803,15 +803,15 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
@ -958,8 +958,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
%add = mhlo.add %lhs, %rhs : tensor<f32>
"mhlo.return"(%add) : (tensor<f32>) -> ()
}) {
scatter_dimension_numbers = {
update_window_dims = dense<[1]> : tensor<1xi64>,
@ -979,8 +979,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = xla_hlo.add %a, %b : tensor<f32>
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
return
}
@ -991,8 +991,8 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = xla_hlo.add %a, %b : tensor<f32>
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> ()
return
}
@ -1012,8 +1012,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}
@ -1025,8 +1025,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}
@ -1038,8 +1038,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}

View File

@ -2,14 +2,14 @@
// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
%4 = "xla_hlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
%4 = "mhlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@ -17,14 +17,14 @@ func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
// CHECK-LABEL: @add_unranked
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
%4 = "xla_hlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
%4 = "mhlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -32,14 +32,14 @@ func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
// CHECK-LABEL: @sub
func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3
%4 = "xla_hlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
%4 = "mhlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@ -47,14 +47,14 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
// CHECK-LABEL: @sub_unranked
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3
%4 = "xla_hlo.subtract"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
%4 = "mhlo.subtract"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -62,18 +62,18 @@ func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
// CHECK-LABEL: @mul
func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
%4 = "mhlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32>
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@ -81,18 +81,18 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
// CHECK-LABEL: @mul_unranked
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.multiply"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
%4 = "mhlo.multiply"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32>
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -100,36 +100,36 @@ func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
// CHECK-LABEL: @div
func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]
// Compute the real valued denominator as rhs * con(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]]
%4 = "xla_hlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@ -139,36 +139,36 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
// CHECK-LABEL: @div_unranked
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]
// Compute the real valued denominator as rhs * con(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]]
%4 = "xla_hlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -176,14 +176,14 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
// CHECK-LABEL: @abs
func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg0
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]])
%1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg0
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]])
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]]
return %2 : tensor<2xf32>
@ -191,16 +191,16 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
// CHECK-LABEL: @exp
func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]]
%1 = "xla_hlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
%1 = "mhlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%3 = "mhlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<2xf32>, tensor<2xf32>
@ -208,16 +208,16 @@ func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tenso
// CHECK-LABEL: @exp_unranked
func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]]
%1 = "xla_hlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
%1 = "mhlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%3 = "mhlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<*xf32>, tensor<*xf32>

View File

@ -2,10 +2,10 @@
// CHECK-LABEL: @testDebatch1
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
// CHECK-DAG: [[R0:%.+]] = "xla_hlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK-DAG: [[R1:%.+]] = "xla_hlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>
// CHECK-DAG: [[R0:%.+]] = "mhlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK-DAG: [[R1:%.+]] = "mhlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>
return %0 : tensor<1x1x3xf32>
}
@ -14,13 +14,13 @@ func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1
// CHECK-LABEL: @testDebatch2
func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> {
// CHECK-DAG: [[R0:%.+]] = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK-DAG: [[R1:%.+]] = "xla_hlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
// CHECK-DAG: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
// CHECK-DAG: [[R3:%.+]] = "xla_hlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
// CHECK: [[R4:%.+]] = "xla_hlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32>
// CHECK-DAG: [[R0:%.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK-DAG: [[R1:%.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
// CHECK-DAG: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
// CHECK-DAG: [[R3:%.+]] = "mhlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
// CHECK: [[R4:%.+]] = "mhlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32>
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
return %0 : tensor<3x1x1xf32>
}
@ -28,8 +28,8 @@ func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3
// CHECK-LABEL: @testBatchPassthrough
func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> {
// CHECK-NEXT: "xla_hlo.dot_general"(%arg0, %arg1)
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32>
// CHECK-NEXT: "mhlo.dot_general"(%arg0, %arg1)
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32>
return %0 : tensor<3x2x1xf32>
}

View File

@ -3,9 +3,9 @@
// CHECK-LABEL: @clampBroadcast
// CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)
func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>) -> tensor<4xf32> {
// CHECK-DAG: %[[MIN_BC:.+]] = "xla_hlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MAX_BC:.+]] = "xla_hlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK: "xla_hlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MIN_BC:.+]] = "mhlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MAX_BC:.+]] = "mhlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK: "mhlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

File diff suppressed because it is too large Load Diff

View File

@ -4,11 +4,11 @@
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
// CHECK: return %[[ARG0]]
func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%2 = "xla_hlo.reduce"(%arg0, %0) ( {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%2 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%4 = xla_hlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
%4 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
return %2 : tensor<4x8xf32>
}

View File

@ -2,9 +2,9 @@
// CHECK-LABEL: func @const_fold_collapse_to_scalar
func @const_fold_collapse_to_scalar() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42> : tensor<1x1xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<1x1xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
@ -13,9 +13,9 @@ func @const_fold_collapse_to_scalar() -> tensor<i32> {
// CHECK-LABEL: func @const_fold_collapse_to_tensor
func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<2xi32>
%cst = xla_hlo.constant dense<42> : tensor<1x2xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<2xi32>
%cst = mhlo.constant dense<42> : tensor<1x2xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2xi32>
}
@ -24,9 +24,9 @@ func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
// CHECK-LABEL: func @const_fold_expand
func @const_fold_expand() -> tensor<1xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<1xi32>
%cst = xla_hlo.constant dense<42> : tensor<i32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<1xi32>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<1xi32>
}
@ -35,9 +35,9 @@ func @const_fold_expand() -> tensor<1xi32> {
// CHECK-LABEL: func @const_fold_nontrivial
func @const_fold_nontrivial() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64>
%cst = xla_hlo.constant dense<42> : tensor<4x4xi64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}
@ -46,9 +46,9 @@ func @const_fold_nontrivial() -> tensor<16xi64> {
// CHECK-LABEL: func @const_fold_flatten
func @const_fold_flatten() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64>
%cst = xla_hlo.constant dense<42> : tensor<4x4xi64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}
@ -57,9 +57,9 @@ func @const_fold_flatten() -> tensor<16xi64> {
// CHECK-LABEL: func @const_fold_6
func @const_fold_6() -> tensor<6xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%cst = xla_hlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<6xi32>
}
@ -68,11 +68,11 @@ func @const_fold_6() -> tensor<6xi32> {
// CHECK-LABEL: func @const_fold_same_shape
func @const_fold_same_shape() -> tensor<2x3xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
// CHECK-SAME: ]> : tensor<2x3xi32>
%cst = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
%cst = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2x3xi32>
}
@ -81,9 +81,9 @@ func @const_fold_same_shape() -> tensor<2x3xi32> {
// CHECK-LABEL: func @const_fold_float
func @const_fold_float() -> tensor<16xf64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
%cst = xla_hlo.constant dense<4.2> : tensor<4x4xf64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
%cst = mhlo.constant dense<4.2> : tensor<4x4xf64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xf64>
}
@ -94,7 +94,7 @@ func @const_fold_float() -> tensor<16xf64> {
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-NEXT: return [[ARG]]
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}
@ -103,10 +103,10 @@ func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-LABEL: func @non_const_chained_reshape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) {
// CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed
}
@ -115,9 +115,9 @@ func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, ten
// CHECK-LABEL: func @non_const_chained_reshape_unused_parent
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: return [[RES]]
return %1 : tensor<6xi32>
}
@ -127,8 +127,8 @@ func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<
// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[ARG]]
return %1 : tensor<2x3xi32>
}
@ -138,12 +138,12 @@ func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2
// CHECK-LABEL: func @non_const_many_chained_reshapes
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
%2 = "xla_hlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
%3 = "xla_hlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
%4 = "xla_hlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
%2 = "mhlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
%3 = "mhlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
%4 = "mhlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
// CHECK-NEXT: return [[RES]]
return %4 : tensor<1x2x4x3xi32>
}

View File

@ -3,7 +3,7 @@
// CHECK-LABEL: func @noop
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>)
func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
%0 = "mhlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK: return %[[ARG0]]
return %0 : tensor<1x2xf32>
}

View File

@ -4,27 +4,27 @@
// CHECK-LABEL: func @sink_const_to_while
func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NEXT: xla_hlo.while
%c0 = xla_hlo.constant dense<1> : tensor<i64>
%c1 = xla_hlo.constant dense<2> : tensor<i64>
%0 = "xla_hlo.while"(%arg0) ( {
// CHECK-NEXT: mhlo.while
%c0 = mhlo.constant dense<1> : tensor<i64>
%c1 = mhlo.constant dense<2> : tensor<i64>
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1A:.+]]: tensor<i64>
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
// CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]])
%1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
// CHECK: "mhlo.compare"(%[[C0]], %[[ARG1A]])
%1 = "mhlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1B:.+]]: tensor<i64>
// CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
// CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]]
%2 = xla_hlo.add %arg1, %arg1 : tensor<i64>
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]]
%3 = xla_hlo.add %c1, %2 : tensor<i64>
// CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]]
%4 = xla_hlo.add %c1, %3 : tensor<i64>
"xla_hlo.return"(%4) : (tensor<i64>) -> ()
// CHECK-DAG: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
// CHECK-DAG: %[[ADD0:.+]] = mhlo.add %[[ARG1B]], %[[ARG1B]]
%2 = mhlo.add %arg1, %arg1 : tensor<i64>
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]], %[[ADD0]]
%3 = mhlo.add %c1, %2 : tensor<i64>
// CHECK: %[[ADD2:.+]] = mhlo.add %[[C1]], %[[ADD1]]
%4 = mhlo.add %c1, %3 : tensor<i64>
"mhlo.return"(%4) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}
@ -33,28 +33,28 @@ func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-LABEL: func @sink_const_to_conditional
func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> {
%c0 = xla_hlo.constant dense<1> : tensor<i64>
%c1 = xla_hlo.constant dense<2> : tensor<i64>
%0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%1 = "xla_hlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
// CHECK: xla_hlo.if
%2 = "xla_hlo.if"(%0, %1, %1) ( {
%c0 = mhlo.constant dense<1> : tensor<i64>
%c1 = mhlo.constant dense<2> : tensor<i64>
%0 = "mhlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%1 = "mhlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
// CHECK: mhlo.if
%2 = "mhlo.if"(%0, %1, %1) ( {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
%3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]],
%4 = xla_hlo.add %c0, %3 : tensor<i64>
%5 = "xla_hlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
"xla_hlo.return"(%5) : (tuple<tensor<i64>>) -> ()
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
%3 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD0:.+]] = mhlo.add %[[C0]],
%4 = mhlo.add %c0, %3 : tensor<i64>
%5 = "mhlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
"mhlo.return"(%5) : (tuple<tensor<i64>>) -> ()
}, {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]],
%7 = xla_hlo.add %c1, %6 : tensor<i64>
%8 = "xla_hlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
"xla_hlo.return"(%8) : (tuple<tensor<i64>>) -> ()
// CHECK: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
%6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]],
%7 = mhlo.add %c1, %6 : tensor<i64>
%8 = "mhlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
"mhlo.return"(%8) : (tuple<tensor<i64>>) -> ()
}) : (tensor<i1>, tuple<tensor<i64>>, tuple<tensor<i64>>) -> tuple<tensor<i64>>
%9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
%9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
return %9 : tensor<i64>
}

View File

@ -3,7 +3,7 @@
// CHECK-LABEL: func @remove_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
%0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x3x9x5xi32>
}
@ -13,8 +13,8 @@ func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
// CHECK-LABEL: func @keep_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
return %0 : tensor<3x2x5x9xi32>
}
@ -23,7 +23,7 @@ func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-LABEL: func @keep_same_shape_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
return %0 : tensor<4x4xi32>
}

View File

@ -4,7 +4,7 @@
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_access(%arg : tensor<i32>) -> tensor<i32> {
// CHECK-NEXT: return [[ARG]]
%tuple = "xla_hlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
%element = "xla_hlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
%tuple = "mhlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
%element = "mhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
return %element : tensor<i32>
}

View File

@ -10,19 +10,19 @@ func @batchNormInference_2D_inner_features(
%x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<4x256xf32>) {
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor<f32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor<f32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} :
(tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<4x256xf32>
@ -36,12 +36,12 @@ func @batchNormInference_2D_inner_features(
// the verifier to enforce the rest.
// CHECK-SAME: %[[X:[^:]+]]
// CHECK-SAME: %[[SCALE:[^:]+]]
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
func @batchNormInference_4D_middle_features(
%x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<3x4x256x6xf32>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} :
(tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<3x4x256x6xf32>
@ -51,12 +51,12 @@ func @batchNormInference_4D_middle_features(
// -----
// CHECK-LABEL: @batchNormInference_f64
// Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f64>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
func @batchNormInference_f64(
%x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>,
%mean: tensor<256xf64>, %variance: tensor<256xf64>)
-> (tensor<4x256xf64>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>,
tensor<256xf64>) -> tensor<4x256xf64>
@ -66,12 +66,12 @@ func @batchNormInference_f64(
// -----
// CHECK-LABEL: @batchNormInference_f16
// Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f16>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f16>
func @batchNormInference_f16(
%x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16>
@ -85,7 +85,7 @@ func @batchNormInference_f16_overflow(
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) {
// expected-warning @+1 {{Could not convert batch_norm epsilon to target fp type: opStatus = 24}}
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.00000001 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16>
@ -108,26 +108,26 @@ func @batchNormInference_dynamic_shape(
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.001 : f32, feature_index = 1 : i64} :
(tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
tensor<?xf32>) -> tensor<?x?x?x?xf32>

View File

@ -2,14 +2,14 @@
// CHECK-LABEL: func @multi_outputs_same
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "xla_hlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:2 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.return
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:2 = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.return
return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
}
@ -17,18 +17,18 @@ func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (ten
// CHECK-LABEL: func @multi_outputs_same_2
func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "xla_hlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = "xla_hlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = "xla_hlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:3 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.return
%0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = "mhlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:3 = "mhlo.fusion"
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.return
return %2, %3, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}
@ -36,9 +36,9 @@ func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (t
// CHECK-LABEL: func @multi_outputs_not_sure_same
func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: xla_hlo.fusion
%1 = "xla_hlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%0 = "mhlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: mhlo.fusion
%1 = "mhlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
@ -46,25 +46,25 @@ func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>
// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET0:.*]] = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.return
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET0:.*]] = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.return
// Currently we do not support fuse arguments and ops without direct producer-consumer
// relationship. Thus Reduce Op should not be fused with above two ops.
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "xla_hlo.reduce"(%arg0, %2) ( {
%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "mhlo.reduce"(%arg0, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// Above two ops should not be fused since reduce op can not be
// fused with its consumer.
// CHECK-NOT: xla_hlo.fusion
// CHECK-NOT: mhlo.fusion
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}
@ -73,25 +73,25 @@ func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>
// CHECK-LABEL: func @reduce_2
func @reduce_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "xla_hlo.reduce"(%1, %2) ( {
%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "mhlo.reduce"(%1, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK: %[[RET0:.*]]:2 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.constant
// CHECK-NEXT: xla_hlo.reduce
// CHECK: xla_hlo.return
// CHECK: %[[RET0:.*]]:2 = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.constant
// CHECK-NEXT: mhlo.reduce
// CHECK: mhlo.return
// Following op should not be fused with the above ops since reduce op can not be
// fused with its consumer.
// CHECK-NOT: xla_hlo.fusion
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NOT: mhlo.fusion
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}

View File

@ -9,15 +9,15 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
%num_elements = shape.num_elements %shape
%num_elements_as_index = shape.size_to_index %num_elements
%flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
%flat_a = "xla_hlo.dynamic_reshape"(%a, %flat_shape)
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// Apply operation.
%flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
%flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
// Restore original shape.
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
%b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
%b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
return %b : tensor<*xf32>
@ -33,12 +33,12 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
%b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
return %b : tensor<*xf32>
}
@ -48,9 +48,9 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-LABEL: @sqrt_ranked
// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
// CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
%b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
return %b : tensor<3x?xf32>
}
@ -60,9 +60,9 @@ func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-LABEL: @sqrt_static
// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %b : tensor<2x3xf32>
}
@ -77,12 +77,12 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "xla_hlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = xla_hlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESULT]] : tensor<*xf32>
%result = xla_hlo.add %a, %b : tensor<*xf32>
%result = mhlo.add %a, %b : tensor<*xf32>
return %result : tensor<*xf32>
}