Rename `xla_hlo` dialect to `mhlo`
This is part of the current refactoring of the HLO related dialect. `xla_hlo` will be reintroduced in a new form later. PiperOrigin-RevId: 319916753
This commit is contained in:
parent
fa057cc0bc
commit
8900222fed
|
@ -17,12 +17,12 @@ limitations under the License.
|
|||
// These ops are not necessarily orthogonal or optimized for transformation but
|
||||
// for ease of expression in certain cases deemed important for client
|
||||
// libraries (i.e. implicit broadcasting, helper ops, etc).
|
||||
// This dialect is considered to exist in addition to augment the xla_hlo
|
||||
// This dialect is considered to exist in addition to augment the mhlo
|
||||
// dialect for ergonomic needs, not duplicate/replace it.
|
||||
//
|
||||
// The typical use of this dialect is for client libraries to be able to emit
|
||||
// less constrained ops and rely on the conversion framework to lower any
|
||||
// xla_chlo ops to canonical xla_hlo ops.
|
||||
// xla_chlo ops to canonical mhlo ops.
|
||||
//
|
||||
// See: https://www.tensorflow.org/xla/operation_semantics
|
||||
|
||||
|
@ -44,7 +44,7 @@ def HLOClient_Dialect : Dialect {
|
|||
let description = [{
|
||||
This dialect contains ops that align closely with the API surface area
|
||||
of the XlaBuilder C++ API, where such ops have semantics that go beyond
|
||||
what exists in the lower level dialects (such as `xla_hlo`). Essentially,
|
||||
what exists in the lower level dialects (such as `mhlo`). Essentially,
|
||||
whenever the client library uses syntactic sugar or composition
|
||||
of multiple ops for an API call, this dialect tries to model the API call
|
||||
and provide conversion patterns to fully materialize into lower level
|
||||
|
@ -65,7 +65,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
|
|||
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
|
||||
// shape broadcasting.
|
||||
//
|
||||
// These correspond to operations in the xla_hlo dialect without the
|
||||
// These correspond to operations in the mhlo dialect without the
|
||||
// "broadcast_" prefix, except that those ops require same-shaped operands and
|
||||
// results.
|
||||
//
|
||||
|
|
|
@ -37,12 +37,12 @@ class OpBuilder;
|
|||
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc"
|
||||
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
class XlaHloDialect : public Dialect {
|
||||
public:
|
||||
explicit XlaHloDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "xla_hlo"; }
|
||||
static StringRef getDialectNamespace() { return "mhlo"; }
|
||||
|
||||
// Registered hook to materialize a constant operation from a given attribute
|
||||
// value with the desired resultant type.
|
||||
|
@ -82,7 +82,7 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
|
|||
// %1 = index_cast %0 : index to i64
|
||||
// %2 = dim %arg0, 1 : memref<?x?xf32>
|
||||
// %3 = index_cast %2 : index to i64
|
||||
// %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
|
||||
// %4 = "mhlo.scalars_to_dimension_tensor"(%1, %3)
|
||||
// : (i64, i64) -> tensor<2xi64>
|
||||
//
|
||||
// and returns %4 as the shape value.
|
||||
|
@ -93,7 +93,7 @@ LogicalResult deriveShapeFromFirstOperand(
|
|||
#define GET_OP_CLASSES
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
|
||||
|
||||
} // end namespace xla_hlo
|
||||
} // end namespace mhlo
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
|
||||
|
|
|
@ -29,8 +29,8 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
|
|||
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
|
||||
|
||||
def HLO_Dialect : Dialect {
|
||||
let name = "xla_hlo";
|
||||
let cppNamespace = "xla_hlo";
|
||||
let name = "mhlo";
|
||||
let cppNamespace = "mhlo";
|
||||
}
|
||||
|
||||
class HLO_Op<string mnemonic, list<OpTrait> traits> :
|
||||
|
|
|
@ -22,7 +22,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
template <typename HloOpTy>
|
||||
struct HloToLhloOpImpl {
|
||||
|
@ -33,7 +33,7 @@ using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
|
|||
|
||||
#define MAP_HLO_TO_LHLO(OpName) \
|
||||
template <> \
|
||||
struct HloToLhloOpImpl<xla_hlo::OpName> { \
|
||||
struct HloToLhloOpImpl<mhlo::OpName> { \
|
||||
using Type = xla_lhlo::OpName; \
|
||||
}
|
||||
|
||||
|
@ -74,7 +74,7 @@ MAP_HLO_TO_LHLO(TanhOp);
|
|||
|
||||
#undef MAP_HLO_TO_LHLO
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
|
||||
|
|
|
@ -464,7 +464,7 @@ struct XlaOpToStdScalarOp {
|
|||
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
|
||||
typename = std::enable_if_t<
|
||||
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
|
||||
std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
|
||||
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
|
||||
std::false_type>::value>>
|
||||
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
|
||||
|
@ -472,8 +472,8 @@ struct XlaOpToStdScalarOp {
|
|||
args, b);
|
||||
}
|
||||
|
||||
// Implementation for HLO ops except xla_hlo::CompareOp.
|
||||
template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
|
||||
// Implementation for HLO ops except mhlo::CompareOp.
|
||||
template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
|
||||
typename = std::enable_if_t<
|
||||
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
|
||||
!std::is_same<LhloOpTy, std::false_type>::value>>
|
||||
|
@ -493,10 +493,11 @@ struct XlaOpToStdScalarOp {
|
|||
op.getLoc(), comparison_direction, result_types, args, b);
|
||||
}
|
||||
|
||||
// Implementation for xla_hlo::CompareOp.
|
||||
template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
|
||||
HloOpTy, xla_hlo::CompareOp>::value>>
|
||||
static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
|
||||
// Implementation for mhlo::CompareOp.
|
||||
template <typename HloOpTy,
|
||||
typename =
|
||||
std::enable_if_t<std::is_same<HloOpTy, mhlo::CompareOp>::value>>
|
||||
static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
auto comparison_direction = op.comparison_direction();
|
||||
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
|
|
|
@ -29,7 +29,7 @@ template <typename T>
|
|||
class OperationPass;
|
||||
class Pass;
|
||||
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
/// Lowers HLO control flow ops to the Standard dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
|
||||
|
@ -55,10 +55,10 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
|
|||
// necessary to export to XLA.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
|
||||
|
||||
// fuse xla_hlo ops to kLoop/kInput fusion patterns
|
||||
// fuse mhlo ops to kLoop/kInput fusion patterns
|
||||
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
|
||||
namespace xla_lhlo {
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ class LLVMTypeConverter;
|
|||
class LowerToLLVMOptions;
|
||||
class OwningRewritePatternList;
|
||||
class BufferAssignmentPlacer;
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
// Collection of rewrite patterns for lowering a general dot product.
|
||||
void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
|
||||
|
@ -73,7 +73,7 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
|||
void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
|
||||
namespace xla_lhlo {
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
|
||||
// Static initialization for XLA dialect registration.
|
||||
static mlir::DialectRegistration<mlir::xla_hlo::XlaHloDialect> xla_hlo_ops;
|
||||
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
|
||||
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
|
||||
xla_chlo_ops;
|
||||
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;
|
||||
|
|
|
@ -60,7 +60,7 @@ limitations under the License.
|
|||
|
||||
namespace mlir {
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
|
||||
Attribute value, Type type,
|
||||
|
@ -68,8 +68,7 @@ Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
|
|||
// HLO dialect constants only support ElementsAttr unlike standard dialect
|
||||
// constant which supports all attributes.
|
||||
if (value.isa<ElementsAttr>())
|
||||
return builder.create<xla_hlo::ConstOp>(loc, type,
|
||||
value.cast<ElementsAttr>());
|
||||
return builder.create<mhlo::ConstOp>(loc, type, value.cast<ElementsAttr>());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -167,7 +166,7 @@ void ConstOp::build(OpBuilder& builder, OperationState& result,
|
|||
}
|
||||
|
||||
// TODO: support other XLA specific types.
|
||||
assert(type && "unsupported attribute type for building xla_hlo.constant");
|
||||
assert(type && "unsupported attribute type for building mhlo.constant");
|
||||
result.types.push_back(type);
|
||||
result.addAttribute("value", value);
|
||||
}
|
||||
|
@ -387,7 +386,7 @@ static LogicalResult Verify(GetTupleElementOp op) {
|
|||
|
||||
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto tupleOp =
|
||||
dyn_cast_or_null<xla_hlo::TupleOp>(getOperand().getDefiningOp())) {
|
||||
dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
|
||||
return tupleOp.getOperand(index().getLimitedValue());
|
||||
}
|
||||
|
||||
|
@ -693,10 +692,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
|
|||
}
|
||||
|
||||
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto real_op =
|
||||
dyn_cast_or_null<xla_hlo::RealOp>(getOperand(0).getDefiningOp());
|
||||
auto imag_op =
|
||||
dyn_cast_or_null<xla_hlo::ImagOp>(getOperand(1).getDefiningOp());
|
||||
auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp());
|
||||
auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp());
|
||||
if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
|
||||
return real_op.getOperand();
|
||||
}
|
||||
|
@ -727,7 +724,7 @@ void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) {
|
|||
|
||||
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto complex_op =
|
||||
dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
|
||||
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
|
||||
return complex_op.getOperand(1);
|
||||
}
|
||||
|
||||
|
@ -740,7 +737,7 @@ void RealOp::build(OpBuilder& builder, OperationState& state, Value val) {
|
|||
|
||||
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto complex_op =
|
||||
dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
|
||||
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
|
||||
return complex_op.getOperand(0);
|
||||
}
|
||||
|
||||
|
@ -1148,7 +1145,7 @@ static LogicalResult Verify(MapOp op) {
|
|||
// RecvOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Checks that the result type is of the form `tuple<any_type, xla_hlo::token>`
|
||||
// Checks that the result type is of the form `tuple<any_type, mhlo::token>`
|
||||
static LogicalResult Verify(RecvOp op) {
|
||||
auto result_ty = op.getResult().getType().cast<TupleType>();
|
||||
auto subtypes = result_ty.getTypes();
|
||||
|
@ -2020,7 +2017,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// xla_hlo Dialect Interfaces
|
||||
// mhlo Dialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
@ -2032,7 +2029,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
|
|||
BlockAndValueMapping& valueMapping) const final {
|
||||
return true;
|
||||
}
|
||||
// Operations in xla_hlo dialect are always legal to inline since they are
|
||||
// Operations in mhlo dialect are always legal to inline since they are
|
||||
// pure.
|
||||
bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
|
||||
return true;
|
||||
|
@ -2041,7 +2038,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
|
|||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// xla_hlo Dialect Constructor
|
||||
// mhlo Dialect Constructor
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
XlaHloDialect::XlaHloDialect(MLIRContext* context)
|
||||
|
@ -2061,8 +2058,7 @@ Type XlaHloDialect::parseType(DialectAsmParser& parser) const {
|
|||
if (parser.parseKeyword(&data_type)) return Type();
|
||||
|
||||
if (data_type == "token") return TokenType::get(getContext());
|
||||
parser.emitError(parser.getNameLoc())
|
||||
<< "unknown xla_hlo type: " << data_type;
|
||||
parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -2071,7 +2067,7 @@ void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const {
|
|||
os << "token";
|
||||
return;
|
||||
}
|
||||
os << "<unknown xla_hlo type>";
|
||||
os << "<unknown mhlo type>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2106,5 +2102,5 @@ LogicalResult deriveShapeFromFirstOperand(
|
|||
return success();
|
||||
}
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace xla_chlo {
|
|||
namespace {
|
||||
|
||||
// Converts binary ops that statically are determined to not broadcast directly
|
||||
// to the corresponding xla_hlo non-broadcasting op.
|
||||
// to the corresponding mhlo non-broadcasting op.
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
|
||||
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
|
||||
|
@ -63,7 +63,7 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
|
|||
};
|
||||
|
||||
// Converts a binary op with ranked broadcasting operands to explicitly
|
||||
// broadcast and invoke the corresponding xla_hlo non-broadcasting op.
|
||||
// broadcast and invoke the corresponding mhlo non-broadcasting op.
|
||||
// Note that dynamic broadcasting supported by this pattern is only valid for
|
||||
// "numpy" broadcasting semantics as defined here:
|
||||
// https://docs.scipy.org/doc/numpy/reference/ufuncs.html
|
||||
|
@ -136,7 +136,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
|
|||
// properly.
|
||||
auto lhs_broadcast_dimensions = llvm::to_vector<4>(
|
||||
llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
|
||||
Value broadcasted_lhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
|
||||
Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
|
||||
loc,
|
||||
RankedTensorType::get(result_type.getShape(),
|
||||
lhs_type.getElementType()),
|
||||
|
@ -144,7 +144,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
|
|||
rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
|
||||
auto rhs_broadcast_dimensions = llvm::to_vector<4>(
|
||||
llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
|
||||
Value broadcasted_rhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
|
||||
Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
|
||||
loc,
|
||||
RankedTensorType::get(result_type.getShape(),
|
||||
rhs_type.getElementType()),
|
||||
|
@ -182,21 +182,19 @@ struct HloBinaryElementwiseAdaptor {
|
|||
};
|
||||
|
||||
struct HloComplexAdaptor {
|
||||
static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op,
|
||||
Type result_type, Value broadcasted_lhs,
|
||||
Value broadcasted_rhs,
|
||||
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<xla_hlo::ComplexOp>(from_op.getLoc(), result_type,
|
||||
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
|
||||
broadcasted_lhs, broadcasted_rhs);
|
||||
}
|
||||
};
|
||||
|
||||
struct HloCompareAdaptor {
|
||||
static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op,
|
||||
Type result_type, Value broadcasted_lhs,
|
||||
Value broadcasted_rhs,
|
||||
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type,
|
||||
return builder.create<mhlo::CompareOp>(from_op.getLoc(), result_type,
|
||||
broadcasted_lhs, broadcasted_rhs,
|
||||
from_op.comparison_direction());
|
||||
}
|
||||
|
@ -214,28 +212,27 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
|||
HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
|
||||
patterns);
|
||||
|
||||
POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp);
|
||||
POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp);
|
||||
POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op);
|
||||
POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp);
|
||||
POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp);
|
||||
POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp);
|
||||
POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp);
|
||||
POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp);
|
||||
POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp);
|
||||
POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp);
|
||||
POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp);
|
||||
POPULATE_BCAST(BroadcastShiftRightArithmeticOp,
|
||||
xla_hlo::ShiftRightArithmeticOp);
|
||||
POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp);
|
||||
POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp);
|
||||
POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp);
|
||||
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
|
||||
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
|
||||
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
|
||||
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
|
||||
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
|
||||
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
|
||||
POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
|
||||
POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
|
||||
POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
|
||||
POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
|
||||
POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
|
||||
POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
|
||||
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
|
||||
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
|
||||
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
|
||||
|
||||
// Broadcasting ops requiring special construction.
|
||||
PopulateForBinaryOp<BroadcastComplexOp, xla_hlo::ComplexOp,
|
||||
HloComplexAdaptor>(context, patterns);
|
||||
PopulateForBinaryOp<BroadcastCompareOp, xla_hlo::CompareOp,
|
||||
HloCompareAdaptor>(context, patterns);
|
||||
PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
|
||||
context, patterns);
|
||||
PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
|
||||
context, patterns);
|
||||
}
|
||||
|
||||
} // namespace xla_chlo
|
||||
|
|
|
@ -32,8 +32,8 @@ struct TestChloLegalizeToHloPass
|
|||
OwningRewritePatternList conversionPatterns;
|
||||
|
||||
conversionTarget.addIllegalDialect<XlaHloClientDialect>();
|
||||
// Consider the xla_hlo dialect legal for tests.
|
||||
conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();
|
||||
// Consider the mhlo dialect legal for tests.
|
||||
conversionTarget.addLegalDialect<mhlo::XlaHloDialect>();
|
||||
// The conversion uses helpers from the Standard dialect.
|
||||
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
|
||||
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
|
||||
|
|
|
@ -37,7 +37,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
|
@ -128,7 +128,7 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
|||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||
}
|
||||
}
|
||||
rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||
buffer_args, op->getAttrs());
|
||||
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||
return success();
|
||||
|
@ -136,12 +136,12 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
|||
};
|
||||
|
||||
struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
: public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
|
||||
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
|
||||
public:
|
||||
using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
|
||||
using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
||||
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = op.getLoc();
|
||||
Value resultBuffer = InsertDynamicAllocAndDealloc(
|
||||
|
@ -162,7 +162,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
|||
// and size of the target dimension if size-1 dimension expansion is
|
||||
// necessary.
|
||||
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
|
||||
xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
|
||||
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
|
||||
auto loc = op.getLoc();
|
||||
auto operand_type = operand.getType().cast<MemRefType>();
|
||||
auto operand_shape = operand_type.getShape();
|
||||
|
@ -220,12 +220,12 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
|||
}
|
||||
};
|
||||
|
||||
struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
|
||||
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
||||
public:
|
||||
using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
|
||||
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_hlo::ReduceOp op, ArrayRef<Value> operands,
|
||||
mhlo::ReduceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = op.getLoc();
|
||||
// TODO(b/137624192) Implement variadic reduce.
|
||||
|
@ -314,10 +314,10 @@ class HloToLhloTensorStoreOpConverter
|
|||
// "xla_lhlo.fusion"() ({
|
||||
// %0 = tensor_load %arg1 : memref<2x2xf32>
|
||||
// %1 = tensor_load %arg2 : memref<2x2xf32>
|
||||
// %2 = "xla_hlo.add"(%0, %1) :
|
||||
// %2 = "mhlo.add"(%0, %1) :
|
||||
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// %3 = tensor_load %arg0 : memref<2x2xf32>
|
||||
// %4 = "xla_hlo.multiply"(%2, %3) :
|
||||
// %4 = "mhlo.multiply"(%2, %3) :
|
||||
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// tensor_store %4, %arg3 : memref<2x2xf32>
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
|
@ -344,8 +344,8 @@ class HloToLhloTensorStoreOpConverter
|
|||
// FuncOp signature conversion example:
|
||||
//
|
||||
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
|
||||
// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>,
|
||||
// %0 = "mhlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
|
||||
// tensor<4xf32> %1 = "mhlo.add"(%arg0, %0) : (tensor<4xf32>,
|
||||
// tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32>
|
||||
// }
|
||||
//
|
||||
|
@ -388,7 +388,7 @@ struct HloLegalizeToLhlo
|
|||
target.addIllegalOp<mlir::TensorStoreOp>();
|
||||
target.addLegalOp<ModuleTerminatorOp>();
|
||||
target.addLegalOp<TensorFromElementsOp>();
|
||||
target.addIllegalDialect<xla_hlo::XlaHloDialect>();
|
||||
target.addIllegalDialect<mhlo::XlaHloDialect>();
|
||||
|
||||
BufferAssignmentTypeConverter converter;
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
|
@ -442,38 +442,38 @@ void populateHLOToLHLOConversionPattern(
|
|||
// clang-format off
|
||||
patterns->insert<
|
||||
HloToLhloDynamicBroadcastInDimOpConverter,
|
||||
HloToLhloOpConverter<xla_hlo::AbsOp>,
|
||||
HloToLhloOpConverter<xla_hlo::AddOp>,
|
||||
HloToLhloOpConverter<xla_hlo::AndOp>,
|
||||
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CeilOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CompareOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ComplexOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConstOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConvOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConvertOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CopyOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CosOp>,
|
||||
HloToLhloOpConverter<xla_hlo::DivOp>,
|
||||
HloToLhloOpConverter<xla_hlo::DotOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ExpOp>,
|
||||
HloToLhloOpConverter<xla_hlo::GatherOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ImagOp>,
|
||||
HloToLhloOpConverter<xla_hlo::IotaOp>,
|
||||
HloToLhloOpConverter<xla_hlo::LogOp>,
|
||||
HloToLhloOpConverter<xla_hlo::MaxOp>,
|
||||
HloToLhloOpConverter<xla_hlo::MinOp>,
|
||||
HloToLhloOpConverter<xla_hlo::MulOp>,
|
||||
HloToLhloOpConverter<xla_hlo::NegOp>,
|
||||
HloToLhloOpConverter<xla_hlo::RealOp>,
|
||||
HloToLhloOpConverter<xla_hlo::RemOp>,
|
||||
HloToLhloOpConverter<xla_hlo::RsqrtOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ReshapeOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SelectOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SignOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SqrtOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SubOp>,
|
||||
HloToLhloOpConverter<xla_hlo::TanhOp>,
|
||||
HloToLhloOpConverter<mhlo::AbsOp>,
|
||||
HloToLhloOpConverter<mhlo::AddOp>,
|
||||
HloToLhloOpConverter<mhlo::AndOp>,
|
||||
HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
|
||||
HloToLhloOpConverter<mhlo::CeilOp>,
|
||||
HloToLhloOpConverter<mhlo::CompareOp>,
|
||||
HloToLhloOpConverter<mhlo::ComplexOp>,
|
||||
HloToLhloOpConverter<mhlo::ConstOp>,
|
||||
HloToLhloOpConverter<mhlo::ConvOp>,
|
||||
HloToLhloOpConverter<mhlo::ConvertOp>,
|
||||
HloToLhloOpConverter<mhlo::CopyOp>,
|
||||
HloToLhloOpConverter<mhlo::CosOp>,
|
||||
HloToLhloOpConverter<mhlo::DivOp>,
|
||||
HloToLhloOpConverter<mhlo::DotOp>,
|
||||
HloToLhloOpConverter<mhlo::ExpOp>,
|
||||
HloToLhloOpConverter<mhlo::GatherOp>,
|
||||
HloToLhloOpConverter<mhlo::ImagOp>,
|
||||
HloToLhloOpConverter<mhlo::IotaOp>,
|
||||
HloToLhloOpConverter<mhlo::LogOp>,
|
||||
HloToLhloOpConverter<mhlo::MaxOp>,
|
||||
HloToLhloOpConverter<mhlo::MinOp>,
|
||||
HloToLhloOpConverter<mhlo::MulOp>,
|
||||
HloToLhloOpConverter<mhlo::NegOp>,
|
||||
HloToLhloOpConverter<mhlo::RealOp>,
|
||||
HloToLhloOpConverter<mhlo::RemOp>,
|
||||
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
||||
HloToLhloOpConverter<mhlo::ReshapeOp>,
|
||||
HloToLhloOpConverter<mhlo::SelectOp>,
|
||||
HloToLhloOpConverter<mhlo::SignOp>,
|
||||
HloToLhloOpConverter<mhlo::SqrtOp>,
|
||||
HloToLhloOpConverter<mhlo::SubOp>,
|
||||
HloToLhloOpConverter<mhlo::TanhOp>,
|
||||
HloToLhloReduceOpConverter,
|
||||
HloToLhloTensorLoadOpConverter,
|
||||
HloToLhloTensorStoreOpConverter
|
||||
|
@ -489,5 +489,5 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
|
|||
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
|
||||
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -35,7 +35,7 @@ limitations under the License.
|
|||
using mlir::PassRegistration;
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
struct LegalizeControlFlow
|
||||
: public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {
|
||||
|
@ -51,7 +51,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
|
|||
OpBuilder* builder) {
|
||||
for (auto& old_block : region->getBlocks()) {
|
||||
Block* block = mapper.lookup(&old_block);
|
||||
auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
|
||||
auto return_op = dyn_cast<mhlo::ReturnOp>(block->getTerminator());
|
||||
if (!return_op) continue;
|
||||
builder->setInsertionPointToEnd(block);
|
||||
builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
|
||||
|
@ -61,7 +61,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
|
||||
LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
|
||||
Operation* op_inst = if_op.getOperation();
|
||||
mlir::OpBuilder builder(if_op);
|
||||
auto orig_block = op_inst->getBlock();
|
||||
|
@ -106,13 +106,13 @@ LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
|
||||
LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
|
||||
// Converts an XLA while loop into control flow. This generates a set of MLIR
|
||||
// blocks and branches, along with inlining the regions provided by the XLA
|
||||
// while loop. The structure should be similar to below:
|
||||
//
|
||||
// <prior operations>
|
||||
// %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
|
||||
// %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
|
||||
// <post operations>
|
||||
auto* op_inst = while_op.getOperation();
|
||||
mlir::OpBuilder builder(while_op);
|
||||
|
@ -147,7 +147,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
|
|||
// extract_element and conditional branch. This changes the block below:
|
||||
// ^cond(%0):
|
||||
// <inlined conditional region>
|
||||
// "xla_hlo".return(%1)
|
||||
// "mhlo".return(%1)
|
||||
//
|
||||
// Into:
|
||||
// ^cond(%0):
|
||||
|
@ -156,14 +156,14 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
|
|||
// cond_br %2, ^body(%0), ^tail(%0) // Branch.
|
||||
builder.setInsertionPointToStart(cond_block);
|
||||
|
||||
// Replace the xla_hlo::ReturnOp with a branch back to the condition block.
|
||||
// This is required as the xla_hlo::ReturnOp is used to mark the end of a
|
||||
// Replace the mhlo::ReturnOp with a branch back to the condition block.
|
||||
// This is required as the mhlo::ReturnOp is used to mark the end of a
|
||||
// block for regions nested inside of a operations (MLIR ReturnOp cannot be
|
||||
// nested within an non-function region).
|
||||
for (auto& block : while_op.cond()) {
|
||||
auto new_block = mapper.lookup(&block);
|
||||
|
||||
auto return_op = dyn_cast<xla_hlo::ReturnOp>(new_block->getTerminator());
|
||||
auto return_op = dyn_cast<mhlo::ReturnOp>(new_block->getTerminator());
|
||||
if (!return_op) continue;
|
||||
builder.setInsertionPointToEnd(new_block);
|
||||
|
||||
|
@ -183,7 +183,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
|
|||
// conditional block. This changes the block below:
|
||||
// ^body(%0):
|
||||
// <inlined body block>
|
||||
// "xla_hlo".return(%1)
|
||||
// "mhlo".return(%1)
|
||||
//
|
||||
// Into:
|
||||
// ^body(%0):
|
||||
|
@ -191,8 +191,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
|
|||
// br ^cond(%0) // Branch.
|
||||
for (auto& block : while_op.body()) {
|
||||
auto new_block = mapper.lookup(&block);
|
||||
auto return_op =
|
||||
dyn_cast<mlir::xla_hlo::ReturnOp>(new_block->getTerminator());
|
||||
auto return_op = dyn_cast<mlir::mhlo::ReturnOp>(new_block->getTerminator());
|
||||
if (!return_op) continue;
|
||||
builder.setInsertionPointToEnd(new_block);
|
||||
builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
|
||||
|
@ -224,14 +223,14 @@ void LegalizeControlFlow::runOnFunction() {
|
|||
}
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
|
||||
mlir::xla_hlo::createLegalizeControlFlowPass() {
|
||||
mlir::mhlo::createLegalizeControlFlowPass() {
|
||||
return std::make_unique<LegalizeControlFlow>();
|
||||
}
|
||||
|
||||
static PassRegistration<mlir::xla_hlo::LegalizeControlFlow> legalize_cf_pass(
|
||||
static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass(
|
||||
"xla-legalize-control-flow",
|
||||
"Legalize from XLA control flow to MLIR control flow");
|
||||
|
|
|
@ -28,14 +28,14 @@ namespace mlir {
|
|||
namespace {
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"
|
||||
} // end anonymous namespace
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
|
||||
class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
|
||||
LogicalResult matchAndRewrite(mhlo::CompareOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto lhs = op.lhs();
|
||||
auto rhs = op.rhs();
|
||||
|
@ -68,11 +68,11 @@ class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
|
|||
}
|
||||
};
|
||||
|
||||
class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
|
||||
class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
|
||||
LogicalResult matchAndRewrite(mhlo::CompareOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto lhs = op.lhs();
|
||||
auto rhs = op.rhs();
|
||||
|
@ -109,11 +109,11 @@ class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
|
|||
// convert the integer constant to iota result type. For complex types, the real
|
||||
// part is replaced with the generated constant and the imaginary part is
|
||||
// replaced with zero tensor.
|
||||
class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
|
||||
class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(xla_hlo::IotaOp op,
|
||||
LogicalResult matchAndRewrite(mhlo::IotaOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto output_type = op.getType().cast<ShapedType>();
|
||||
auto output_size = output_type.getNumElements();
|
||||
|
@ -168,8 +168,7 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
|
|||
loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
|
||||
auto imag_zeroes =
|
||||
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
|
||||
rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
|
||||
imag_zeroes);
|
||||
rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iota_const, imag_zeroes);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -197,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
|
|||
/// Perform the lowering to standard dialect.
|
||||
void LegalizeToStandard::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
|
||||
mlir::mhlo::PopulateXlaToStdPatterns(&patterns, &getContext());
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeToStandard> legalize_pass(
|
||||
"xla-legalize-to-std", "Legalize from XLA dialect to standard dialect");
|
||||
|
||||
} // end namespace xla_hlo
|
||||
} // end namespace mhlo
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -84,13 +84,13 @@ Value TransposeReshape(Value arg, mlir::Location loc,
|
|||
transposed_shape.push_back(arg_shape[val]);
|
||||
}
|
||||
auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
|
||||
auto transpose_result = rewriter->create<mlir::xla_hlo::TransposeOp>(
|
||||
auto transpose_result = rewriter->create<mlir::mhlo::TransposeOp>(
|
||||
loc, transpose_type, arg, transpose_permutation_attr);
|
||||
|
||||
// Return the final result.
|
||||
auto reshaped_type =
|
||||
RankedTensorType::get({left_size, right_size}, element_type);
|
||||
return rewriter->create<mlir::xla_hlo::ReshapeOp>(loc, reshaped_type,
|
||||
return rewriter->create<mlir::mhlo::ReshapeOp>(loc, reshaped_type,
|
||||
transpose_result);
|
||||
}
|
||||
|
||||
|
@ -125,8 +125,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc,
|
|||
return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
|
||||
}
|
||||
|
||||
struct GeneralDotConvert
|
||||
: public OpRewritePattern<mlir::xla_hlo::DotGeneralOp> {
|
||||
struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
|
||||
// Attempts to lower a General Dot operator to a standard Dot operator.
|
||||
// General dots include batching dimensions and can have collapsing
|
||||
// dimensions along any axis. Inserting correctly arrange transpose and
|
||||
|
@ -138,7 +137,7 @@ struct GeneralDotConvert
|
|||
explicit GeneralDotConvert(MLIRContext *context)
|
||||
: OpRewritePattern(context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op,
|
||||
LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto dot_element_type = mlir::getElementTypeOrSelf(op);
|
||||
|
||||
|
@ -162,10 +161,10 @@ struct GeneralDotConvert
|
|||
auto new_dot_type =
|
||||
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
|
||||
|
||||
auto new_dot_op = rewriter.create<mlir::xla_hlo::DotOp>(
|
||||
auto new_dot_op = rewriter.create<mlir::mhlo::DotOp>(
|
||||
op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config()));
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::xla_hlo::ReshapeOp>(op, op.getType(),
|
||||
rewriter.replaceOpWithNewOp<mlir::mhlo::ReshapeOp>(op, op.getType(),
|
||||
new_dot_op);
|
||||
return success();
|
||||
}
|
||||
|
@ -176,15 +175,14 @@ struct LegalizeGeneralDot
|
|||
/// Lower all general dots that can be represented as a non-batched matmul.
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
|
||||
&getContext());
|
||||
mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(
|
||||
void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
|
||||
OwningRewritePatternList *patterns, MLIRContext *ctx) {
|
||||
patterns->insert<GeneralDotConvert>(ctx);
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -86,5 +86,5 @@ void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
|
|||
patterns->insert<ClampWithBroadcastConvert>(context);
|
||||
}
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -33,7 +33,7 @@ struct TestMaterializeBroadcastsPass
|
|||
ConversionTarget conversionTarget(getContext());
|
||||
OwningRewritePatternList conversionPatterns;
|
||||
|
||||
// Consider the xla_hlo dialect legal for tests.
|
||||
// Consider the mhlo dialect legal for tests.
|
||||
conversionTarget.addLegalDialect<XlaHloDialect>();
|
||||
// The conversion uses helpers from the Standard dialect.
|
||||
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
|
||||
|
@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass
|
|||
|
||||
} // namespace
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
||||
static mlir::PassRegistration<mlir::xla_hlo::TestMaterializeBroadcastsPass>
|
||||
pass("test-xla-materialize-broadcasts",
|
||||
static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
|
||||
"test-xla-materialize-broadcasts",
|
||||
"Test pass for materializing 'broadcast_dimensions' attributes");
|
||||
|
|
|
@ -60,7 +60,7 @@ limitations under the License.
|
|||
// shape dialect once it is ready.
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
using llvm::EquivalenceClasses;
|
||||
|
@ -544,7 +544,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
|
|||
}
|
||||
|
||||
FusionOp fusion =
|
||||
b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs);
|
||||
b.create<mhlo::FusionOp>(fused_loc, output_types, inputs);
|
||||
Region& region = fusion.fused_computation();
|
||||
region.push_back(new Block);
|
||||
Block& block = region.front();
|
||||
|
@ -552,7 +552,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
|
|||
op->moveBefore(&block, block.end());
|
||||
}
|
||||
b.setInsertionPoint(&block, block.end());
|
||||
b.create<xla_hlo::ReturnOp>(fused_loc, outputs);
|
||||
b.create<mhlo::ReturnOp>(fused_loc, outputs);
|
||||
|
||||
for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
|
||||
Value output = std::get<0>(output_and_result);
|
||||
|
@ -572,8 +572,8 @@ std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
|
|||
return std::make_unique<XlaHloFusion>();
|
||||
}
|
||||
|
||||
static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass(
|
||||
"xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns.");
|
||||
static PassRegistration<XlaHloFusion> mhlo_fusion_pass(
|
||||
"xla-hlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -81,5 +81,5 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
|
|||
return std::make_unique<SinkConstantsToControlFlow>();
|
||||
}
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -25,7 +25,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -40,11 +40,11 @@ Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
|
|||
auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
|
||||
auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
|
||||
if (shape_value) {
|
||||
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
|
||||
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
|
||||
loc, result_type, value_1d, shape_value, dims);
|
||||
}
|
||||
assert(result_type.hasStaticShape());
|
||||
return rewriter.create<xla_hlo::BroadcastInDimOp>(loc, result_type, value_1d,
|
||||
return rewriter.create<mhlo::BroadcastInDimOp>(loc, result_type, value_1d,
|
||||
dims);
|
||||
}
|
||||
|
||||
|
@ -89,25 +89,25 @@ Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
|
|||
auto epsilon_tensor_attr =
|
||||
DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
|
||||
Value epsilon =
|
||||
rewriter.create<xla_hlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
|
||||
rewriter.create<mhlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
|
||||
auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
|
||||
auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
|
||||
if (broadcast_to_type.hasStaticShape()) {
|
||||
return rewriter.create<xla_hlo::BroadcastInDimOp>(
|
||||
return rewriter.create<mhlo::BroadcastInDimOp>(
|
||||
op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
|
||||
}
|
||||
Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
|
||||
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
|
||||
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
|
||||
op->getLoc(), broadcast_to_type, epsilon, shape_value,
|
||||
/*broadcast_dims=*/dims);
|
||||
}
|
||||
|
||||
class UnfuseBatchNormInferencePattern
|
||||
: public OpRewritePattern<xla_hlo::BatchNormInferenceOp> {
|
||||
: public OpRewritePattern<mhlo::BatchNormInferenceOp> {
|
||||
public:
|
||||
using OpRewritePattern<xla_hlo::BatchNormInferenceOp>::OpRewritePattern;
|
||||
using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op,
|
||||
LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
// Enforce type invariants.
|
||||
// Note that we deduce the actual element type from the variance,
|
||||
|
@ -132,9 +132,9 @@ class UnfuseBatchNormInferencePattern
|
|||
if (!epsilon) {
|
||||
return failure();
|
||||
}
|
||||
Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(),
|
||||
bn_op.variance(), epsilon);
|
||||
stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);
|
||||
Value stddev =
|
||||
rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon);
|
||||
stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev);
|
||||
|
||||
// Broadcast all terms.
|
||||
Value shape_value;
|
||||
|
@ -156,14 +156,13 @@ class UnfuseBatchNormInferencePattern
|
|||
|
||||
// Compute:
|
||||
// scale * (input - mean) / stddev + offset
|
||||
Value result = rewriter.create<xla_hlo::SubOp>(
|
||||
bn_op.getLoc(), bn_op.operand(), broadcast_mean);
|
||||
result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result,
|
||||
broadcast_scale);
|
||||
result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result,
|
||||
broadcast_stddev);
|
||||
rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result,
|
||||
broadcast_offset);
|
||||
Value result = rewriter.create<mhlo::SubOp>(bn_op.getLoc(), bn_op.operand(),
|
||||
broadcast_mean);
|
||||
result =
|
||||
rewriter.create<mhlo::MulOp>(bn_op.getLoc(), result, broadcast_scale);
|
||||
result =
|
||||
rewriter.create<mhlo::DivOp>(bn_op.getLoc(), result, broadcast_stddev);
|
||||
rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, result, broadcast_offset);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -180,5 +179,5 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
|
|||
patterns->insert<UnfuseBatchNormInferencePattern>(context);
|
||||
}
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass
|
|||
|
||||
} // namespace
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
||||
static mlir::PassRegistration<mlir::xla_hlo::TestUnfuseBatchNormPass> pass(
|
||||
static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass(
|
||||
"test-xla-unfuse-batch-norm",
|
||||
"Test pass for materializing 'broadcast_dimensions' attributes");
|
||||
|
|
|
@ -182,7 +182,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
|
|||
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
|
||||
|
||||
// This code has been adapted from IREE's
|
||||
// (https://github.com/google/iree/) xla_hlo -> linalg conversion.
|
||||
// (https://github.com/google/iree/) mhlo -> linalg conversion.
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ConvOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
|
@ -348,14 +348,14 @@ class BroadcastConverter
|
|||
|
||||
class HloBroadcastInDimConverter
|
||||
: public DataMovementOpConverter<HloBroadcastInDimConverter,
|
||||
xla_hlo::BroadcastInDimOp, false> {
|
||||
mhlo::BroadcastInDimOp, false> {
|
||||
public:
|
||||
using DataMovementOpConverter<HloBroadcastInDimConverter,
|
||||
xla_hlo::BroadcastInDimOp,
|
||||
mhlo::BroadcastInDimOp,
|
||||
false>::DataMovementOpConverter;
|
||||
|
||||
static SmallVector<AffineMap, 2> getIndexingMaps(
|
||||
xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) {
|
||||
mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
|
||||
auto resultType = getXLAOpResultType<false>(broadcastOp);
|
||||
auto operandType =
|
||||
broadcastOp.operand().getType().template cast<ShapedType>();
|
||||
|
@ -845,7 +845,7 @@ struct HloLegalizeToLinalg
|
|||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
|
||||
|
||||
auto func = getFunction();
|
||||
xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
|
@ -863,40 +863,40 @@ static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
|
|||
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
|
||||
} // namespace xla_lhlo
|
||||
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
|
||||
patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
|
||||
HloBroadcastInDimConverter,
|
||||
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::ComplexOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::ConvertOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::CosOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::DivOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::ImagOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::LogOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::MaxOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::MinOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::MulOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::NegOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::RealOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SinOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
|
||||
ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
|
||||
ReverseConverter<xla_hlo::ReverseOp, false>,
|
||||
TransposeConverter<xla_hlo::TransposeOp, false>>(context);
|
||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CosOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::DivOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SinOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||
ReverseConverter<mhlo::ReverseOp, false>,
|
||||
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||
|
@ -905,5 +905,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
|||
|
||||
static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass(
|
||||
"hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect");
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -28,7 +28,7 @@ limitations under the License.
|
|||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
// TODO(frgossen): Make it variadic.
|
||||
|
@ -69,7 +69,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|||
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
|
||||
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
operandTy.getElementType());
|
||||
Value flatOperand = rewriter.create<xla_hlo::DynamicReshapeOp>(
|
||||
Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, flatTensorTy, operand, flatShapeAsDimTensor);
|
||||
|
||||
// Generate IR for the actual operation.
|
||||
|
@ -80,7 +80,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|||
rewriter.getIndexType());
|
||||
Value shapeAsExtentTensor =
|
||||
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
|
||||
Value result = rewriter.create<xla_hlo::DynamicReshapeOp>(
|
||||
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, operandTy, flatResult, shapeAsExtentTensor);
|
||||
rewriter.replaceOp(op, result);
|
||||
|
||||
|
@ -184,5 +184,5 @@ static PassRegistration<TransformUnrankedHloPass> transform_unranked_hlo_pass(
|
|||
"transform-unranked-hlo",
|
||||
"Realize element-wise operations on ranked tensors where possible");
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -2,107 +2,107 @@
|
|||
|
||||
// CHECK-LABEL: add_fold
|
||||
func @add_fold() -> tensor<4xi64> {
|
||||
%0 = xla_hlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
|
||||
%1 = xla_hlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<[6, 8, 10, 12]>
|
||||
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
%0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
|
||||
%1 = mhlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
|
||||
// CHECK: mhlo.constant dense<[6, 8, 10, 12]>
|
||||
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
return %2 : tensor<4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: add_scalar_fold
|
||||
func @add_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = xla_hlo.constant dense<1> : tensor<4xi64>
|
||||
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<6>
|
||||
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
%0 = mhlo.constant dense<1> : tensor<4xi64>
|
||||
%1 = mhlo.constant dense<5> : tensor<4xi64>
|
||||
// CHECK: mhlo.constant dense<6>
|
||||
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
return %2 : tensor<4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: add_fold_float
|
||||
func @add_fold_float() -> tensor<4xf64> {
|
||||
%0 = xla_hlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
|
||||
%1 = xla_hlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
|
||||
// CHECK: xla_hlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
|
||||
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
|
||||
%0 = mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
|
||||
%1 = mhlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
|
||||
// CHECK: mhlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
|
||||
%2 = "mhlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
|
||||
return %2 : tensor<4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: sub_scalar_fold
|
||||
func @sub_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = xla_hlo.constant dense<5> : tensor<4xi64>
|
||||
%1 = xla_hlo.constant dense<1> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<4>
|
||||
%2 = "xla_hlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
%0 = mhlo.constant dense<5> : tensor<4xi64>
|
||||
%1 = mhlo.constant dense<1> : tensor<4xi64>
|
||||
// CHECK: mhlo.constant dense<4>
|
||||
%2 = "mhlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
return %2 : tensor<4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: multiply_scalar_fold
|
||||
func @multiply_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = xla_hlo.constant dense<5> : tensor<4xi64>
|
||||
%1 = xla_hlo.constant dense<3> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<15>
|
||||
%2 = "xla_hlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
%0 = mhlo.constant dense<5> : tensor<4xi64>
|
||||
%1 = mhlo.constant dense<3> : tensor<4xi64>
|
||||
// CHECK: mhlo.constant dense<15>
|
||||
%2 = "mhlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
return %2 : tensor<4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: divide_scalar_fold
|
||||
func @divide_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
|
||||
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<1>
|
||||
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
||||
%1 = mhlo.constant dense<5> : tensor<4xi64>
|
||||
// CHECK: mhlo.constant dense<1>
|
||||
%2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
return %2 : tensor<4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: divide_fold_float
|
||||
func @divide_fold_float() -> tensor<4xf64> {
|
||||
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
|
||||
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
|
||||
// CHECK: xla_hlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
|
||||
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
|
||||
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
|
||||
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
|
||||
// CHECK: mhlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
|
||||
%2 = "mhlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
|
||||
return %2 : tensor<4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: max_scalar_fold
|
||||
func @max_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
|
||||
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<7>
|
||||
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
||||
%1 = mhlo.constant dense<5> : tensor<4xi64>
|
||||
// CHECK: mhlo.constant dense<7>
|
||||
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
return %2 : tensor<4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: max_fold_float
|
||||
func @max_fold_float() -> tensor<4xf64> {
|
||||
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
|
||||
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
|
||||
// CHECK: xla_hlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
|
||||
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
|
||||
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
|
||||
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
|
||||
// CHECK: mhlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
|
||||
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
|
||||
return %2 : tensor<4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: min_scalar_fold
|
||||
func @min_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
|
||||
%1 = xla_hlo.constant dense<-5> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<-5>
|
||||
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
||||
%1 = mhlo.constant dense<-5> : tensor<4xi64>
|
||||
// CHECK: mhlo.constant dense<-5>
|
||||
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
return %2 : tensor<4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: min_fold_float
|
||||
func @min_fold_float() -> tensor<4xf64> {
|
||||
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
|
||||
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
|
||||
// CHECK: xla_hlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
|
||||
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
|
||||
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
|
||||
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
|
||||
// CHECK: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
|
||||
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
|
||||
return %2 : tensor<4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_noop
|
||||
func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
|
||||
%0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
|
||||
%0 = "mhlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK: return [[ARG]]
|
||||
return %0 : tensor<4xi32>
|
||||
|
@ -112,7 +112,7 @@ func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
|
|||
func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
|
||||
// CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
|
||||
// CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK: return [[ARG0]]
|
||||
return %0 : tensor<4xi32>
|
||||
|
@ -120,34 +120,34 @@ func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) ->
|
|||
|
||||
// CHECK-LABEL: concatenate_empty_bool
|
||||
func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
|
||||
// CHECK: xla_hlo.constant
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
|
||||
// CHECK: mhlo.constant
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
|
||||
|
||||
return %0 : tensor<0xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_empty_int
|
||||
func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
|
||||
// CHECK: xla_hlo.constant
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
|
||||
// CHECK: mhlo.constant
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
|
||||
|
||||
return %0 : tensor<0xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_empty_float
|
||||
func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
|
||||
// CHECK: xla_hlo.constant
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
|
||||
// CHECK: mhlo.constant
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
|
||||
|
||||
return %0 : tensor<0xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_const_1D
|
||||
func @concatenate_const_1D() -> tensor<4xi32> {
|
||||
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[0, 1, 2, 3]>
|
||||
%0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32>
|
||||
%1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32>
|
||||
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
|
||||
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[0, 1, 2, 3]>
|
||||
%0 = mhlo.constant dense<[0, 1]> : tensor<2xi32>
|
||||
%1 = mhlo.constant dense<[2, 3]> : tensor<2xi32>
|
||||
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK: return [[VAL]]
|
||||
return %2 : tensor<4xi32>
|
||||
|
@ -155,11 +155,11 @@ func @concatenate_const_1D() -> tensor<4xi32> {
|
|||
|
||||
// CHECK-LABEL: concatenate_const_1D_float
|
||||
func @concatenate_const_1D_float() -> tensor<4xf32> {
|
||||
// CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
|
||||
// CHECK: [[VAL:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
|
||||
|
||||
%0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
|
||||
%1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
|
||||
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
|
||||
%0 = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
|
||||
%1 = mhlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
|
||||
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK: return [[VAL]]
|
||||
return %2 : tensor<4xf32>
|
||||
|
@ -167,12 +167,12 @@ func @concatenate_const_1D_float() -> tensor<4xf32> {
|
|||
|
||||
// CHECK-LABEL: concatenate_const_2D_vertical
|
||||
func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
|
||||
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
|
||||
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
|
||||
// CHECK-SAME: [0, 1], [2, 3]
|
||||
// CHECK-SAME: ]>
|
||||
%0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
|
||||
%1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
|
||||
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
|
||||
%0 = mhlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
|
||||
%1 = mhlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
|
||||
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
|
||||
|
||||
// CHECK: return [[VAL]]
|
||||
return %2 : tensor<2x2xi32>
|
||||
|
@ -180,12 +180,12 @@ func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
|
|||
|
||||
// CHECK-LABEL: concatenate_const_2D_horizontal
|
||||
func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
|
||||
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
|
||||
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
|
||||
// CHECK-SAME: [0, 2], [1, 3]
|
||||
// CHECK-SAME: ]>
|
||||
%0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
|
||||
%1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
|
||||
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
||||
%0 = mhlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
|
||||
%1 = mhlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
|
||||
%2 = "mhlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
||||
|
||||
// CHECK: return [[VAL]]
|
||||
return %2 : tensor<2x2xi32>
|
||||
|
@ -193,40 +193,40 @@ func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
|
|||
|
||||
// CHECK-LABEL: dynamic_slice_variable_start
|
||||
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
|
||||
// CHECK: "xla_hlo.dynamic-slice"
|
||||
%1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
|
||||
// CHECK: "mhlo.dynamic-slice"
|
||||
%1 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
|
||||
return %1 : tensor<1x4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: dynamic_slice_constant_start
|
||||
func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
|
||||
// CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
|
||||
// CHECK-DAG-SAME: limit_indices = dense<3> : tensor<1xi64>
|
||||
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
|
||||
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>}
|
||||
// CHECK: return %[[RESULT]] : tensor<2xi32>
|
||||
%0 = xla_hlo.constant dense<1> : tensor<i64>
|
||||
%1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
|
||||
%0 = mhlo.constant dense<1> : tensor<i64>
|
||||
%1 = "mhlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
|
||||
return %1 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape
|
||||
func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<?x4xi32> {
|
||||
// CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
|
||||
// CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64>
|
||||
// CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64>
|
||||
// CHECK: return %[[RESULT]] : tensor<?x4xi32>
|
||||
%0 = xla_hlo.constant dense<1> : tensor<i64>
|
||||
%1 = xla_hlo.constant dense<0> : tensor<i64>
|
||||
%2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
|
||||
%0 = mhlo.constant dense<1> : tensor<i64>
|
||||
%1 = mhlo.constant dense<0> : tensor<i64>
|
||||
%2 = "mhlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
|
||||
return %2 : tensor<?x4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_2D_noop
|
||||
// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64>
|
||||
func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
|
||||
%0 = "xla_hlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)
|
||||
%0 = "mhlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)
|
||||
|
||||
// CHECK-NEXT: return [[ARG]]
|
||||
return %0 : tensor<2x2xi64>
|
||||
|
@ -234,80 +234,80 @@ func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
|
|||
|
||||
// CHECK-LABEL: slice_1D_fold
|
||||
func @slice_1D_fold() -> tensor<2xi64> {
|
||||
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<[7, 9]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
|
||||
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
|
||||
// CHECK: mhlo.constant dense<[7, 9]>
|
||||
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
|
||||
return %1 : tensor<2xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_1D_fp
|
||||
func @slice_1D_fp() -> tensor<2xf32> {
|
||||
%0 = xla_hlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
|
||||
// CHECK: xla_hlo.constant dense<[7.000000e+00, 9.000000e+00]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
|
||||
%0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
|
||||
// CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]>
|
||||
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
|
||||
return %1 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_1D_strided_fold
|
||||
func @slice_1D_strided_fold() -> tensor<2xi64> {
|
||||
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<[7, 10]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
|
||||
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
|
||||
// CHECK: mhlo.constant dense<[7, 10]>
|
||||
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
|
||||
return %1 : tensor<2xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_2D_fold
|
||||
func @slice_2D_fold() -> tensor<2x2xi64> {
|
||||
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
|
||||
// CHECK-NEXT: xla_hlo.constant dense<[
|
||||
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
|
||||
// CHECK-NEXT: mhlo.constant dense<[
|
||||
// CHECK-SAME: [6, 7],
|
||||
// CHECK-SAME: [10, 11]
|
||||
// CHECK-SAME: ]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
|
||||
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
|
||||
return %1 : tensor<2x2xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_2D_fold_horizontal
|
||||
func @slice_2D_fold_horizontal() -> tensor<1x4xi64> {
|
||||
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
|
||||
// CHECK-NEXT: xla_hlo.constant dense<[
|
||||
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
|
||||
// CHECK-NEXT: mhlo.constant dense<[
|
||||
// CHECK-SAME: [0, 1, 2, 3]
|
||||
// CHECK-SAME: ]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
|
||||
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
|
||||
return %1 : tensor<1x4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_2D_fold_vertical
|
||||
func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
|
||||
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
|
||||
// CHECK-NEXT: xla_hlo.constant dense<[
|
||||
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
|
||||
// CHECK-NEXT: mhlo.constant dense<[
|
||||
// CHECK-SAME: [2], [6], [10], [14]
|
||||
// CHECK-SAME: ]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
|
||||
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
|
||||
return %1 : tensor<4x1xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_concat_fold_first
|
||||
func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
|
||||
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
|
||||
// CHECK: return %arg0
|
||||
return %1 : tensor<1x5xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_concat_fold_second
|
||||
func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
|
||||
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
|
||||
// CHECK: return %arg1
|
||||
return %1 : tensor<1x5xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_concat_fold_second_with_slice
|
||||
func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> {
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
|
||||
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
|
||||
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
|
||||
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)
|
||||
|
||||
// CHECK: return [[SLICE]]
|
||||
return %1 : tensor<1x4xf32>
|
||||
|
@ -315,9 +315,9 @@ func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<
|
|||
|
||||
// CHECK-LABEL: slice_concat_fold_middle
|
||||
func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> {
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
|
||||
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
|
||||
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)
|
||||
|
||||
// CHECK: return [[SLICE]]
|
||||
return %1 : tensor<1x5xf32>
|
||||
|
@ -325,11 +325,11 @@ func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %
|
|||
|
||||
// CHECK-LABEL: slice_concat_fold_two
|
||||
func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> {
|
||||
// CHECK: [[CONCAT:%.+]] = "xla_hlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
|
||||
// CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
|
||||
|
||||
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)
|
||||
// CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)
|
||||
|
||||
// CHECK: return [[SLICE]]
|
||||
return %1 : tensor<2x5xf32>
|
||||
|
@ -338,72 +338,72 @@ func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg
|
|||
// CHECK-LABEL: func @broadcast_in_dim_identity
|
||||
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
|
||||
// CHECK: return %arg0
|
||||
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
|
||||
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
|
||||
return %0 : tensor<2x3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts
|
||||
func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: xla_hlo.broadcast_in_dim
|
||||
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: mhlo.broadcast_in_dim
|
||||
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation
|
||||
func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: xla_hlo.broadcast_in_dim
|
||||
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: mhlo.broadcast_in_dim
|
||||
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
|
||||
func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
|
||||
// CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
|
||||
%0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
|
||||
// CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
|
||||
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<5x4xf32>
|
||||
return %0 : tensor<5x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @complex_expand_fold
|
||||
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
|
||||
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
|
||||
%1 = "xla_hlo.real"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
|
||||
%2 = "xla_hlo.imag"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
|
||||
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
|
||||
%1 = "mhlo.real"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
|
||||
%2 = "mhlo.imag"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
|
||||
// CHECK: return %arg0, %arg1
|
||||
return %1, %2 : tensor<4xf32>, tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @complex_collapse_fold
|
||||
func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>> {
|
||||
%0 = "xla_hlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
|
||||
%1 = "xla_hlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
|
||||
%2 = "xla_hlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
|
||||
%0 = "mhlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
|
||||
%1 = "mhlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
|
||||
%2 = "mhlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
|
||||
// CHECK: return %arg0
|
||||
return %2 : tensor<4xcomplex<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @dynamic_iota_is_static
|
||||
func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> {
|
||||
// CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
|
||||
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
|
||||
// CHECK: return [[RESULT]]
|
||||
%0 = "xla_hlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
|
||||
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @iota_not_lowered_to_constant
|
||||
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
|
||||
// CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
|
||||
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
|
||||
// CHECK: return [[RESULT]]
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
|
||||
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @unary_einsum
|
||||
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
|
||||
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: "mhlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
|
||||
%0 = "mhlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
|
@ -411,30 +411,30 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
|||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
|
||||
// CHECK: return [[ARG]]
|
||||
%0 = "xla_hlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
|
||||
%0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
|
||||
return %0 : tensor<1x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
|
||||
func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
|
||||
// CHECK: xla_hlo.reshape
|
||||
%0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
|
||||
// CHECK: mhlo.reshape
|
||||
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
|
||||
return %0 : tensor<4x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: do_not_dce_while_with_outfeed
|
||||
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
// CHECK: xla_hlo.while
|
||||
%0 = "xla_hlo.while"(%arg0) ( {
|
||||
// CHECK: mhlo.while
|
||||
%0 = "mhlo.while"(%arg0) ( {
|
||||
^bb0(%arg1: tensor<i64>):
|
||||
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
|
||||
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"mhlo.return"(%1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg1: tensor<i64>):
|
||||
%1 = "xla_hlo.create_token"() : () -> !xla_hlo.token
|
||||
%1 = "mhlo.create_token"() : () -> !mhlo.token
|
||||
// Side-effecting op outfeed present inside while.
|
||||
%2 = "xla_hlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor<i64>, !xla_hlo.token) -> !xla_hlo.token
|
||||
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
|
||||
%2 = "mhlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor<i64>, !mhlo.token) -> !mhlo.token
|
||||
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
|
||||
}) : (tensor<i64>) -> tensor<i64>
|
||||
|
||||
return %arg0 : tensor<i64>
|
||||
|
@ -442,15 +442,15 @@ func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
|
|||
|
||||
// CHECK-LABEL: dce_while_without_side_effect
|
||||
func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
// CHECK-NOT: xla_hlo.while
|
||||
%0 = "xla_hlo.while"(%arg0) ( {
|
||||
// CHECK-NOT: mhlo.while
|
||||
%0 = "mhlo.while"(%arg0) ( {
|
||||
^bb0(%arg1: tensor<i64>):
|
||||
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
|
||||
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"mhlo.return"(%1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg1: tensor<i64>):
|
||||
%1 = "xla_hlo.create_token"() : () -> !xla_hlo.token
|
||||
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
|
||||
%1 = "mhlo.create_token"() : () -> !mhlo.token
|
||||
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
|
||||
}) : (tensor<i64>) -> tensor<i64>
|
||||
|
||||
return %arg0 : tensor<i64>
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
// representative op for detailed broadcast semantics.
|
||||
// CHECK-LABEL: @addWithoutBroadcast
|
||||
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.add %arg0, %arg1
|
||||
// CHECK: mhlo.add %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -20,9 +20,9 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
|
|||
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
|
||||
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
||||
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]]
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
|
||||
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
|
||||
|
@ -41,9 +41,9 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
|
|||
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
|
||||
// CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
||||
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
|
||||
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
|
||||
|
@ -62,9 +62,9 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
|
|||
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
|
||||
// CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
||||
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: shape.assuming_yield %[[RESULT]]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
|
||||
|
@ -76,7 +76,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
|
|||
// Verifies that broadcast_dimensions validity checks are valid.
|
||||
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
|
||||
func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
|
||||
// CHECK: xla_hlo.add
|
||||
// CHECK: mhlo.add
|
||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||
return %0 : tensor<1x4xf32>
|
||||
}
|
||||
|
@ -85,7 +85,7 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<
|
|||
// Verifies that broadcast_dimensions validity checks are valid.
|
||||
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
|
||||
func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
|
||||
// CHECK: xla_hlo.add
|
||||
// CHECK: mhlo.add
|
||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
|
||||
return %0 : tensor<1x4xf32>
|
||||
}
|
||||
|
@ -113,7 +113,7 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1:
|
|||
// expansions. Tests below merely verify that the op has an expansion.
|
||||
// CHECK-LABEL: @andWithoutBroadcast
|
||||
func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
||||
// CHECK: xla_hlo.and %arg0, %arg1
|
||||
// CHECK: mhlo.and %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
||||
return %0 : tensor<4xi1>
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
|
|||
// -----
|
||||
// CHECK-LABEL: @atan2WithoutBroadcast
|
||||
func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.atan2 %arg0, %arg1
|
||||
// CHECK: mhlo.atan2 %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -129,7 +129,7 @@ func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
|
|||
// -----
|
||||
// CHECK-LABEL: @compareWithoutBroadcast
|
||||
func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> {
|
||||
// CHECK: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
return %0 : tensor<4xi1>
|
||||
}
|
||||
|
@ -137,7 +137,7 @@ func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
|
|||
// -----
|
||||
// CHECK-LABEL: @complexWithoutBroadcast
|
||||
func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> {
|
||||
// CHECK: "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
|
||||
// CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
|
||||
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
|
||||
return %0 : tensor<4xcomplex<f32>>
|
||||
}
|
||||
|
@ -145,7 +145,7 @@ func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
|
|||
// -----
|
||||
// CHECK-LABEL: @divideWithoutBroadcast
|
||||
func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.divide %arg0, %arg1
|
||||
// CHECK: mhlo.divide %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -153,7 +153,7 @@ func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tens
|
|||
// -----
|
||||
// CHECK-LABEL: @maximumWithoutBroadcast
|
||||
func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.maximum %arg0, %arg1
|
||||
// CHECK: mhlo.maximum %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -161,7 +161,7 @@ func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
|
|||
// -----
|
||||
// CHECK-LABEL: @minimumWithoutBroadcast
|
||||
func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.minimum %arg0, %arg1
|
||||
// CHECK: mhlo.minimum %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -169,7 +169,7 @@ func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
|
|||
// -----
|
||||
// CHECK-LABEL: @multiplyWithoutBroadcast
|
||||
func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.multiply %arg0, %arg1
|
||||
// CHECK: mhlo.multiply %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -177,7 +177,7 @@ func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te
|
|||
// -----
|
||||
// CHECK-LABEL: @orWithoutBroadcast
|
||||
func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
||||
// CHECK: xla_hlo.or %arg0, %arg1
|
||||
// CHECK: mhlo.or %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
||||
return %0 : tensor<4xi1>
|
||||
}
|
||||
|
@ -185,7 +185,7 @@ func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi
|
|||
// -----
|
||||
// CHECK-LABEL: @powerWithoutBroadcast
|
||||
func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.power %arg0, %arg1
|
||||
// CHECK: mhlo.power %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -193,7 +193,7 @@ func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
|
|||
// -----
|
||||
// CHECK-LABEL: @remainderWithoutBroadcast
|
||||
func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.remainder %arg0, %arg1
|
||||
// CHECK: mhlo.remainder %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -201,7 +201,7 @@ func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> t
|
|||
// -----
|
||||
// CHECK-LABEL: @shift_leftWithoutBroadcast
|
||||
func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.shift_left %arg0, %arg1
|
||||
// CHECK: mhlo.shift_left %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -209,7 +209,7 @@ func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) ->
|
|||
// -----
|
||||
// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast
|
||||
func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1
|
||||
// CHECK: mhlo.shift_right_arithmetic %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -217,7 +217,7 @@ func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor
|
|||
// -----
|
||||
// CHECK-LABEL: @shift_right_logicalWithoutBroadcast
|
||||
func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.shift_right_logical %arg0, %arg1
|
||||
// CHECK: mhlo.shift_right_logical %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -225,7 +225,7 @@ func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4x
|
|||
// -----
|
||||
// CHECK-LABEL: @subWithoutBroadcast
|
||||
func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: xla_hlo.subtract %arg0, %arg1
|
||||
// CHECK: mhlo.subtract %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -233,7 +233,7 @@ func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<
|
|||
// -----
|
||||
// CHECK-LABEL: @xorWithoutBroadcast
|
||||
func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
||||
// CHECK: xla_hlo.xor %arg0, %arg1
|
||||
// CHECK: mhlo.xor %arg0, %arg1
|
||||
%0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
||||
return %0 : tensor<4xi1>
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
// CHECK-LABEL: func @single_operand
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> {
|
||||
%0 = "xla_hlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
|
||||
%0 = "mhlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
|
||||
// CHECK-NEXT: return [[ARG]]
|
||||
return %0 : tensor<1x2xf32>
|
||||
}
|
|
@ -5,7 +5,7 @@
|
|||
// CHECK-LABEL: func @same_type
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @same_type(%arg: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "xla_hlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>
|
||||
%0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK-NEXT: return [[ARG]]
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
@ -15,8 +15,8 @@ func @same_type(%arg: tensor<f32>) -> tensor<f32> {
|
|||
// CHECK-LABEL: func @int_widening
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
|
||||
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i64>
|
||||
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<i64>
|
||||
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i64>
|
||||
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i64>
|
||||
// CHECK-NEXT: return [[RES]]
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
|
@ -26,8 +26,8 @@ func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
|
|||
// CHECK-LABEL: func @int_narrowing
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
|
||||
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i16>
|
||||
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<i16>
|
||||
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i16>
|
||||
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i16>
|
||||
// CHECK-NEXT: return [[RES]]
|
||||
return %0 : tensor<i16>
|
||||
}
|
||||
|
@ -37,8 +37,8 @@ func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
|
|||
// CHECK-LABEL: func @float_int
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @float_int(%arg: tensor<f32>) -> tensor<i32> {
|
||||
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<f32>) -> tensor<i32>
|
||||
%0 = "xla_hlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>
|
||||
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<f32>) -> tensor<i32>
|
||||
%0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>
|
||||
// CHECK-NEXT: return [[RES]]
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
@ -48,8 +48,8 @@ func @float_int(%arg: tensor<f32>) -> tensor<i32> {
|
|||
// CHECK-LABEL: func @int_float
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @int_float(%arg: tensor<i32>) -> tensor<f32> {
|
||||
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<f32>
|
||||
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<f32>
|
||||
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<f32>
|
||||
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<f32>
|
||||
// CHECK-NEXT: return [[RES]]
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
@ -59,8 +59,8 @@ func @int_float(%arg: tensor<i32>) -> tensor<f32> {
|
|||
// CHECK-LABEL: func @high_rank_tensor
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {
|
||||
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
|
||||
%0 = "xla_hlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
|
||||
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
|
||||
%0 = "mhlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
|
||||
// CHECK-NEXT: return [[RES]]
|
||||
return %0 : tensor<2x3xf32>
|
||||
}
|
||||
|
@ -70,9 +70,9 @@ func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {
|
|||
|
||||
// CHECK-LABEL: func @const_same_type
|
||||
func @const_same_type() -> tensor<i32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
|
||||
%cst = xla_hlo.constant dense<42> : tensor<i32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
|
||||
%cst = mhlo.constant dense<42> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
@ -81,9 +81,9 @@ func @const_same_type() -> tensor<i32> {
|
|||
|
||||
// CHECK-LABEL: func @const_float_int
|
||||
func @const_float_int() -> tensor<i32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
|
||||
%cst = xla_hlo.constant dense<42.0> : tensor<f32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
|
||||
%cst = mhlo.constant dense<42.0> : tensor<f32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
@ -92,9 +92,9 @@ func @const_float_int() -> tensor<i32> {
|
|||
|
||||
// CHECK-LABEL: func @const_int_float
|
||||
func @const_int_float() -> tensor<f32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor<f32>
|
||||
%cst = xla_hlo.constant dense<4> : tensor<i32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<f32>
|
||||
%cst = mhlo.constant dense<4> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
@ -103,9 +103,9 @@ func @const_int_float() -> tensor<f32> {
|
|||
|
||||
// CHECK-LABEL: func @const_negative_int_float
|
||||
func @const_negative_int_float() -> tensor<f32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
|
||||
%cst = xla_hlo.constant dense<-4> : tensor<i32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
|
||||
%cst = mhlo.constant dense<-4> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
@ -114,9 +114,9 @@ func @const_negative_int_float() -> tensor<f32> {
|
|||
|
||||
// CHECK-LABEL: func @const_int_bf16
|
||||
func @const_int_bf16() -> tensor<bf16> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
|
||||
%cst = xla_hlo.constant dense<4> : tensor<i32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
|
||||
%cst = mhlo.constant dense<4> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<bf16>
|
||||
}
|
||||
|
@ -125,9 +125,9 @@ func @const_int_bf16() -> tensor<bf16> {
|
|||
|
||||
// CHECK-LABEL: func @const_bf16_int
|
||||
func @const_bf16_int() -> tensor<i16> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i16>
|
||||
%cst = xla_hlo.constant dense<42.0> : tensor<bf16>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
|
||||
%cst = mhlo.constant dense<42.0> : tensor<bf16>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i16>
|
||||
}
|
||||
|
@ -136,9 +136,9 @@ func @const_bf16_int() -> tensor<i16> {
|
|||
|
||||
// CHECK-LABEL: func @const_int_narrowing
|
||||
func @const_int_narrowing() -> tensor<i32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
|
||||
%cst = xla_hlo.constant dense<42> : tensor<i64>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
|
||||
%cst = mhlo.constant dense<42> : tensor<i64>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
@ -147,9 +147,9 @@ func @const_int_narrowing() -> tensor<i32> {
|
|||
|
||||
// CHECK-LABEL: func @const_int_widening
|
||||
func @const_int_widening() -> tensor<i64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i64>
|
||||
%cst = xla_hlo.constant dense<42> : tensor<i32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
|
||||
%cst = mhlo.constant dense<42> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
|
@ -158,9 +158,9 @@ func @const_int_widening() -> tensor<i64> {
|
|||
|
||||
// CHECK-LABEL: func @const_negative_int_widening
|
||||
func @const_negative_int_widening() -> tensor<i64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-42> : tensor<i64>
|
||||
%cst = xla_hlo.constant dense<-42> : tensor<i32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
|
||||
%cst = mhlo.constant dense<-42> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
|
@ -169,9 +169,9 @@ func @const_negative_int_widening() -> tensor<i64> {
|
|||
|
||||
// CHECK-LABEL: func @const_float_narrowing
|
||||
func @const_float_narrowing() -> tensor<f32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
|
||||
%cst = xla_hlo.constant dense<4.2> : tensor<f64>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
|
||||
%cst = mhlo.constant dense<4.2> : tensor<f64>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
@ -180,9 +180,9 @@ func @const_float_narrowing() -> tensor<f32> {
|
|||
|
||||
// CHECK-LABEL: func @const_f32_bf16
|
||||
func @const_f32_bf16() -> tensor<bf16> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
|
||||
%cst = xla_hlo.constant dense<42.0> : tensor<f32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
|
||||
%cst = mhlo.constant dense<42.0> : tensor<f32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<bf16>
|
||||
}
|
||||
|
@ -191,9 +191,9 @@ func @const_f32_bf16() -> tensor<bf16> {
|
|||
|
||||
// CHECK-LABEL: func @const_bf16_f64
|
||||
func @const_bf16_f64() -> tensor<f64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor<f64>
|
||||
%cst = xla_hlo.constant dense<4.2> : tensor<bf16>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.187500e+00> : tensor<f64>
|
||||
%cst = mhlo.constant dense<4.2> : tensor<bf16>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<f64>
|
||||
}
|
||||
|
@ -202,9 +202,9 @@ func @const_bf16_f64() -> tensor<f64> {
|
|||
|
||||
// CHECK-LABEL: func @const_bf16_int
|
||||
func @const_bf16_int() -> tensor<i64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i64>
|
||||
%cst = xla_hlo.constant dense<42.0> : tensor<bf16>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
|
||||
%cst = mhlo.constant dense<42.0> : tensor<bf16>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
|
@ -214,11 +214,11 @@ func @const_bf16_int() -> tensor<i64> {
|
|||
|
||||
// CHECK-LABEL: func @const_high_rank_tensor
|
||||
func @const_high_rank_tensor() -> tensor<2x3xi32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
|
||||
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
|
||||
// CHECK-SAME: ]> : tensor<2x3xi32>
|
||||
%cst = xla_hlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
|
||||
%cst = mhlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<2x3xi32>
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
// BOTH-LABEL: func @attrs
|
||||
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
|
||||
%tensor_result = "mhlo.exponential"(%tensor_operand)
|
||||
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
|
@ -28,11 +28,11 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
|||
|
||||
// BOTH-LABEL: func @func_op_long
|
||||
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
%1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
|
||||
%2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
|
||||
%3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
|
||||
%4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32>
|
||||
%5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
|
||||
%1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
|
||||
%2 = mhlo.add %arg0, %1 : tensor<4xf32>
|
||||
%3 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
|
||||
%4 = mhlo.subtract %arg1, %3 : tensor<4xf32>
|
||||
%5 = mhlo.multiply %2, %4 : tensor<4xf32>
|
||||
return %5 : tensor<4xf32>
|
||||
}
|
||||
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
|
||||
|
@ -65,12 +65,12 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
|||
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
|
||||
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
|
||||
%sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
|
||||
%sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
|
||||
%tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
|
||||
|
@ -86,7 +86,7 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
|||
// BOTH-LABEL: func @copy
|
||||
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.copy"(%tensor_operand)
|
||||
%tensor_result = "mhlo.copy"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -98,7 +98,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// BOTH-LABEL: func @exp
|
||||
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
|
||||
%tensor_result = "mhlo.exponential"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -110,7 +110,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// BOTH-LABEL: func @log
|
||||
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.log"(%tensor_operand)
|
||||
%tensor_result = "mhlo.log"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -125,7 +125,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
|||
%tensor_pred = tensor_load %pred : memref<2x2xi1>
|
||||
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
|
||||
%tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -138,7 +138,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
|||
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
|
||||
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs)
|
||||
%tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
|
||||
{comparison_direction = "EQ"}
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
|
||||
|
@ -151,7 +151,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
|
|||
// BOTH-LABEL: func @broadcast
|
||||
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<5xf32>
|
||||
%tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand)
|
||||
%tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
|
||||
{broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
: (tensor<5xf32>) -> tensor<10x5xf32>
|
||||
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
|
@ -170,7 +170,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
|||
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
|
||||
%tensor_operand = tensor_load %operand : memref<?x?xf32>
|
||||
%shape = call @external_func() : () -> tensor<3xi64>
|
||||
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
|
||||
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
|
||||
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// BOTH: %[[SHAPE:.*]] = call @external_func()
|
||||
|
@ -226,7 +226,7 @@ func @complex(%real: memref<2x2xf32>,
|
|||
%result: memref<2x2xcomplex<f32>>) {
|
||||
%tensor_real = tensor_load %real : memref<2x2xf32>
|
||||
%tensor_imag = tensor_load %imag : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag)
|
||||
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
|
||||
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
|
||||
|
@ -238,7 +238,7 @@ func @complex(%real: memref<2x2xf32>,
|
|||
// BOTH-LABEL: func @real
|
||||
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||
%tensor_result = "xla_hlo.real"(%tensor_operand)
|
||||
%tensor_result = "mhlo.real"(%tensor_operand)
|
||||
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -250,7 +250,7 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
|||
// BOTH-LABEL: func @imag
|
||||
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||
%tensor_result = "xla_hlo.imag"(%tensor_operand)
|
||||
%tensor_result = "mhlo.imag"(%tensor_operand)
|
||||
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -261,7 +261,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
|||
|
||||
// BOTH-LABEL: func @iota
|
||||
func @iota(%result: memref<10xi32>) {
|
||||
%tensor_result = "xla_hlo.iota"()
|
||||
%tensor_result = "mhlo.iota"()
|
||||
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
|
||||
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
|
||||
tensor_store %tensor_result, %result : memref<10xi32>
|
||||
|
@ -273,7 +273,7 @@ func @iota(%result: memref<10xi32>) {
|
|||
// BOTH-LABEL: func @abs
|
||||
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.abs"(%tensor_operand)
|
||||
%tensor_result = "mhlo.abs"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -285,7 +285,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// BOTH-LABEL: func @ceil
|
||||
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.ceil"(%tensor_operand)
|
||||
%tensor_result = "mhlo.ceil"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -297,7 +297,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// BOTH-LABEL: func @convert
|
||||
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.convert"(%tensor_operand)
|
||||
%tensor_result = "mhlo.convert"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// BOTH-NOT: tensor_store
|
||||
|
@ -310,7 +310,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// BOTH-LABEL: func @cos
|
||||
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.cosine"(%tensor_operand)
|
||||
%tensor_result = "mhlo.cosine"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -322,7 +322,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// BOTH-LABEL: func @neg
|
||||
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.negate"(%tensor_operand)
|
||||
%tensor_result = "mhlo.negate"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -334,7 +334,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// BOTH-LABEL: func @rsqrt
|
||||
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.rsqrt"(%tensor_operand)
|
||||
%tensor_result = "mhlo.rsqrt"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -346,7 +346,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// BOTH-LABEL: func @sign
|
||||
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.sign"(%tensor_operand)
|
||||
%tensor_result = "mhlo.sign"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -358,7 +358,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// BOTH-LABEL: func @sqrt
|
||||
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.sqrt"(%tensor_operand)
|
||||
%tensor_result = "mhlo.sqrt"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -370,7 +370,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
// BOTH-LABEL: func @tanh
|
||||
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.tanh"(%tensor_operand)
|
||||
%tensor_result = "mhlo.tanh"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -383,7 +383,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs)
|
||||
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
|
@ -395,7 +395,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
|
|||
// Dynamic shape binary element-wise operation.
|
||||
// BOTH-LABEL: func @add_dyn
|
||||
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
|
||||
%result = "xla_hlo.add"(%lhs, %rhs)
|
||||
%result = "mhlo.add"(%lhs, %rhs)
|
||||
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// BOTH: %[[C0:.*]] = constant 0 : index
|
||||
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
|
@ -420,7 +420,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
|
|||
// Dynamic shape unary element-wise operation.
|
||||
// BOTH-LABEL: func @tanh_dyn
|
||||
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
|
||||
%result = "xla_hlo.tanh"(%arg0)
|
||||
%result = "mhlo.tanh"(%arg0)
|
||||
: (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// BOTH: %[[C0:.*]] = constant 0 : index
|
||||
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
|
@ -448,7 +448,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
|||
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
|
||||
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||
%dot = "xla_hlo.dot"(%arg0, %arg0)
|
||||
%dot = "mhlo.dot"(%arg0, %arg0)
|
||||
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
|
||||
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
|
||||
// ESC: return %[[ALLOC]]
|
||||
|
@ -466,7 +466,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
|
|||
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
|
||||
// BOTH-SAME: window_strides = dense<[2, 1]>
|
||||
%out = "xla_hlo.convolution"(%filter, %input) {
|
||||
%out = "mhlo.convolution"(%filter, %input) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
input_batch_dimension = 0 : i64,
|
||||
|
|
|
@ -10,7 +10,7 @@ func @float_add(%lhs: tensor<2x2xf32>,
|
|||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
|
||||
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = addf %[[ARG0]], %[[ARG1]]
|
||||
// CHECK: linalg.yield %[[RESULT]]
|
||||
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
|
||||
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
|
||||
tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ func @integer_add(%lhs: tensor<2x2xi32>,
|
|||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: addi
|
||||
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ func @float_mul(%lhs: tensor<2x2xf32>,
|
|||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: mulf
|
||||
%0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
|
||||
%0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
|
||||
tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
|
|||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: muli
|
||||
%0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
%0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
@ -58,7 +58,7 @@ func @float_remainder(%lhs: tensor<2x2xf32>,
|
|||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: remf
|
||||
%0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
|
||||
%0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
|
||||
tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
|
|||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: remi_signed
|
||||
%0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
%0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
@ -79,7 +79,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
|
|||
|
||||
// CHECK-LABEL: func @float_rsqrt
|
||||
func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%tensor_result = "xla_hlo.rsqrt"(%operand)
|
||||
%tensor_result = "mhlo.rsqrt"(%operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: rsqrt
|
||||
|
@ -93,7 +93,7 @@ func @float_sub(%lhs: tensor<2x2xf32>,
|
|||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: subf
|
||||
%0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
|
||||
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
|
||||
tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
@ -105,7 +105,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
|
|||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: subi
|
||||
%0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
@ -116,7 +116,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
|
|||
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: absf
|
||||
%0 = "xla_hlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%0 = "mhlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
|
@ -126,7 +126,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||
func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: exp
|
||||
%0 = "xla_hlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%0 = "mhlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
|
@ -136,7 +136,7 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||
func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: log
|
||||
%0 = "xla_hlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%0 = "mhlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
|
@ -146,7 +146,7 @@ func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: ceilf
|
||||
%0 = "xla_hlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%0 = "mhlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
|
@ -156,7 +156,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||
func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: negf
|
||||
%0 = "xla_hlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%0 = "mhlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
|
@ -166,7 +166,7 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: tanh
|
||||
%0 = "xla_hlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%0 = "mhlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
|
@ -177,7 +177,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
|
|||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: and
|
||||
%0 = "xla_hlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
%0 = "mhlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
@ -187,7 +187,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
|
|||
// CHECK-LABEL: func @float_cmp
|
||||
func @float_cmp(%lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
|
||||
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
|
||||
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
return %0 : tensor<2x2xi1>
|
||||
}
|
||||
|
@ -201,7 +201,7 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
|
|||
// CHECK-LABEL: func @int_cmp
|
||||
func @int_cmp(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
|
||||
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
|
||||
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
|
||||
: (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
|
||||
return %0 : tensor<2x2xi1>
|
||||
}
|
||||
|
@ -216,7 +216,7 @@ func @int_cmp(%lhs: tensor<2x2xi32>,
|
|||
func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: cos
|
||||
%0 = "xla_hlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%0 = "mhlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
|
@ -226,7 +226,7 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||
func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: sin
|
||||
%0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%0 = "mhlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
|
@ -235,7 +235,7 @@ func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||
// CHECK-LABEL: func @copy
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
|
||||
%0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
|
||||
%0 = "mhlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
|
||||
return %0 : tensor<2x4x8xf32>
|
||||
}
|
||||
// CHECK: return [[ARG]] : tensor<2x4x8xf32>
|
||||
|
@ -245,7 +245,7 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
|
|||
// CHECK-LABEL: func @select
|
||||
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%0 = "xla_hlo.select"(%pred, %lhs, %rhs)
|
||||
%0 = "mhlo.select"(%pred, %lhs, %rhs)
|
||||
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
@ -260,7 +260,7 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
|
|||
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-LABEL: func @broadcast_scalar
|
||||
func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
|
||||
%0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
|
||||
%0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
|
||||
return %0: tensor<4x2x1xf32>
|
||||
}
|
||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||
|
@ -273,7 +273,7 @@ func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
|
|||
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
|
||||
// CHECK-LABEL: func @broadcast
|
||||
func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
|
||||
%0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
|
||||
%0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
|
||||
return %0: tensor<4x2x1x4x?x16xf32>
|
||||
}
|
||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||
|
@ -286,7 +286,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
|
|||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
|
||||
// CHECK-LABEL: func @broadcast_in_dim
|
||||
func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
|
||||
%0 = "xla_hlo.broadcast_in_dim"(%operand)
|
||||
%0 = "mhlo.broadcast_in_dim"(%operand)
|
||||
{broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>}
|
||||
: (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32>
|
||||
return %0 : tensor<7x10x6x4x5xf32>
|
||||
|
@ -302,7 +302,7 @@ func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
|
|||
// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one
|
||||
func @broadcast_in_dim_with_one_to_one(
|
||||
%operand: tensor<1xf32>) -> tensor<1x5xf32> {
|
||||
%0 = "xla_hlo.broadcast_in_dim"(%operand)
|
||||
%0 = "mhlo.broadcast_in_dim"(%operand)
|
||||
{broadcast_dimensions = dense<[0]> : tensor<1xi64>}
|
||||
: (tensor<1xf32>) -> tensor<1x5xf32>
|
||||
return %0 : tensor<1x5xf32>
|
||||
|
@ -317,7 +317,7 @@ func @broadcast_in_dim_with_one_to_one(
|
|||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-LABEL: func @broadcast_scalar
|
||||
func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
|
||||
%0 = "xla_hlo.broadcast_in_dim"(%operand)
|
||||
%0 = "mhlo.broadcast_in_dim"(%operand)
|
||||
{broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
: (tensor<f32>) -> tensor<7x10x6xf32>
|
||||
return %0 : tensor<7x10x6xf32>
|
||||
|
@ -332,7 +332,7 @@ func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
|
|||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-LABEL: func @transpose
|
||||
func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
|
||||
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
|
||||
%0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
|
||||
: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
|
||||
return %0 : tensor<3x2x5x9xi32>
|
||||
}
|
||||
|
@ -344,7 +344,7 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
|
|||
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
|
||||
// CHECK-LABEL: func @reshape_3D_2D
|
||||
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
|
||||
return %0 : tensor<12x42xi32>
|
||||
}
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
|
||||
|
@ -355,7 +355,7 @@ func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
|
|||
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
|
||||
// CHECK-LABEL: func @reshape_4D_2D
|
||||
func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
|
||||
return %0 : tensor<12x42xi32>
|
||||
}
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
|
||||
|
@ -366,7 +366,7 @@ func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
|
|||
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
|
||||
// CHECK-LABEL: func @reshape_2D_4D
|
||||
func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
|
||||
return %0 : tensor<12x1x42x1xi32>
|
||||
}
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
|
||||
|
@ -375,7 +375,7 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
|
|||
|
||||
// CHECK-LABEL: func @minf
|
||||
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%0 = "xla_hlo.minimum"(%lhs, %rhs)
|
||||
%0 = "mhlo.minimum"(%lhs, %rhs)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
@ -389,7 +389,7 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||
|
||||
// CHECK-LABEL: func @maxi
|
||||
func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
%0 = "xla_hlo.maximum"(%lhs, %rhs)
|
||||
%0 = "mhlo.maximum"(%lhs, %rhs)
|
||||
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
@ -404,7 +404,7 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
|||
// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()>
|
||||
// CHECK-LABEL: func @add_scalar
|
||||
func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -417,7 +417,7 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
|
|||
|
||||
func @reshape_collapse_single_dim
|
||||
(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
|
||||
return %0 : tensor<1x784xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
|
||||
|
@ -428,7 +428,7 @@ func @reshape_collapse_single_dim
|
|||
// -----
|
||||
|
||||
func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
|
||||
return %0 : tensor<2x4x3xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
|
||||
|
@ -440,7 +440,7 @@ func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
|
|||
// -----
|
||||
|
||||
func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
|
||||
return %0 : tensor<2x4x2xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)>
|
||||
|
@ -451,7 +451,7 @@ func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
|
|||
// -----
|
||||
|
||||
func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
|
||||
return %0 : tensor<1x4x2xf32>
|
||||
}
|
||||
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
|
@ -462,7 +462,7 @@ func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
|
|||
|
||||
func @reshape_multiple_collapse
|
||||
(%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
|
||||
return %0 : tensor<1x4x5x6xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
|
||||
|
@ -476,7 +476,7 @@ func @reshape_multiple_collapse
|
|||
|
||||
// CHECK-LABEL: func @convert_i32_to_f32
|
||||
func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
|
||||
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
|
||||
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
|
||||
return %result : tensor<2x2xf32>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -488,7 +488,7 @@ func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
|
|||
|
||||
// CHECK-LABEL: func @convert_i16_to_i32
|
||||
func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
|
||||
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
|
||||
%result = "mhlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
|
||||
return %result : tensor<2x2xi32>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -500,7 +500,7 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
|
|||
|
||||
// CHECK-LABEL: func @convert_i32_to_i16
|
||||
func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
|
||||
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
|
||||
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
|
||||
return %result : tensor<2x2xi16>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -512,7 +512,7 @@ func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
|
|||
|
||||
// CHECK-LABEL: func @convert_f32_to_f64
|
||||
func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
|
||||
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
|
||||
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
|
||||
return %result : tensor<2x2xf64>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -524,7 +524,7 @@ func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
|
|||
|
||||
// CHECK-LABEL: func @convert_f64_to_f32
|
||||
func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
|
||||
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
|
||||
%result = "mhlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
|
||||
return %result : tensor<2x2xf32>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -536,7 +536,7 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
|
|||
|
||||
// CHECK-LABEL: func @convert_f32_to_i32
|
||||
func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
|
||||
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
|
||||
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
|
||||
return %result : tensor<2x2xi32>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
|
@ -550,7 +550,7 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
|
|||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @reverse
|
||||
func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%result = "xla_hlo.reverse"(%input) {
|
||||
%result = "mhlo.reverse"(%input) {
|
||||
dimensions = dense<1> : tensor<1xi64>
|
||||
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %result : tensor<2x3xf32>
|
||||
|
|
|
@ -1,28 +1,28 @@
|
|||
// RUN: mlir-hlo-opt %s -inline | FileCheck %s
|
||||
|
||||
// Test case: Basic test of inlining into xla_hlo.while.
|
||||
// Test case: Basic test of inlining into mhlo.while.
|
||||
|
||||
// CHECK-LABEL: func @caller
|
||||
// CHECK: "xla_hlo.while"{{.*}}( {
|
||||
// CHECK: "mhlo.while"{{.*}}( {
|
||||
// CHECK: }, {
|
||||
// CHECK: "xla_hlo.exponential"
|
||||
// CHECK: "mhlo.exponential"
|
||||
// CHECK: })
|
||||
// CHECK-LABEL: func @callee
|
||||
|
||||
func @caller(%arg0: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
|
||||
%0 = "xla_hlo.while"(%arg0) ( {
|
||||
%0 = "mhlo.while"(%arg0) ( {
|
||||
^entry(%unused: tensor<f32>):
|
||||
"xla_hlo.return"(%pred) : (tensor<i1>) -> ()
|
||||
"mhlo.return"(%pred) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^entry(%0: tensor<f32>):
|
||||
%1 = call @callee(%0) : (tensor<f32>) -> (tensor<f32>)
|
||||
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
|
||||
"mhlo.return"(%1) : (tensor<f32>) -> ()
|
||||
} ) : (tensor<f32>) -> (tensor<f32>)
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
|
||||
func @callee(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "xla_hlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
%0 = "mhlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
|
|
@ -4,21 +4,21 @@
|
|||
func @while(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
//CHECK: br ^bb1(%arg0 : tensor<i64>)
|
||||
//CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
|
||||
//CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]])
|
||||
//CHECK: [[VAL1:%.+]] = "mhlo.compare"([[VAL0]], [[VAL0]])
|
||||
//CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor<i1>
|
||||
//CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor<i64>), ^bb3([[VAL0]] : tensor<i64>)
|
||||
//CHECK: ^bb2([[VAL3:%.+]]: tensor<i64>):
|
||||
//CHECK: [[VAL4:%.+]] = xla_hlo.add [[VAL3]], [[VAL3]]
|
||||
//CHECK: [[VAL4:%.+]] = mhlo.add [[VAL3]], [[VAL3]]
|
||||
//CHECK: br ^bb1([[VAL4]] : tensor<i64>)
|
||||
//CHECK: ^bb3([[VAL5:%.+]]: tensor<i64>):
|
||||
%0 = "xla_hlo.while"(%arg0) ( {
|
||||
%0 = "mhlo.while"(%arg0) ( {
|
||||
^bb0(%arg1: tensor<i64>):
|
||||
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
|
||||
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"mhlo.return"(%1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg1: tensor<i64>):
|
||||
%1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
|
||||
"xla_hlo.return"(%1) : (tensor<i64>) -> ()
|
||||
%1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
|
||||
"mhlo.return"(%1) : (tensor<i64>) -> ()
|
||||
}) : (tensor<i64>) -> tensor<i64>
|
||||
|
||||
// CHECK-NEXT: return [[VAL5]]
|
||||
|
@ -30,27 +30,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
|
|||
// CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor<f32>
|
||||
%cst = constant dense<1.000000e+01> : tensor<f32>
|
||||
|
||||
// CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
%0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
%0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
|
||||
// CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
|
||||
// CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
|
||||
%1 = "xla_hlo.if"(%0, %arg0, %arg0) ( {
|
||||
%1 = "mhlo.if"(%0, %arg0, %arg0) ( {
|
||||
|
||||
^bb0(%arg1: tensor<f32>):
|
||||
// CHECK: ^bb1([[VAL2:%.+]]: tensor<f32>):
|
||||
// CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: [[VAL3:%.+]] = "mhlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: br ^bb3([[VAL3]] : tensor<f32>)
|
||||
%2 = "xla_hlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
|
||||
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
|
||||
%2 = "mhlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
|
||||
"mhlo.return"(%2) : (tensor<f32>) -> ()
|
||||
}, {
|
||||
|
||||
^bb0(%arg1: tensor<f32>):
|
||||
// CHECK: ^bb2([[VAL4:%.+]]: tensor<f32>):
|
||||
// CHECK: [[VAL5:%.+]] = "xla_hlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: br ^bb3([[VAL5]] : tensor<f32>)
|
||||
%2 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
|
||||
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
|
||||
%2 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
|
||||
"mhlo.return"(%2) : (tensor<f32>) -> ()
|
||||
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
|
||||
// CHECK: ^bb3([[VAL6:%.+]]: tensor<f32>):
|
||||
|
@ -62,27 +62,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
|
|||
func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
// CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
|
||||
// CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
|
||||
// CHECK: %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK: %2 = extract_element %1[] : tensor<i1>
|
||||
// CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
|
||||
// CHECK: ^[[BODY_ENTRY]](%3: tensor<i64>):
|
||||
// CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor<i64>)
|
||||
// CHECK: ^[[BODY_SUCC]](%4: tensor<i64>):
|
||||
// CHECK: %5 = xla_hlo.add %4, %4 : tensor<i64>
|
||||
// CHECK: %5 = mhlo.add %4, %4 : tensor<i64>
|
||||
// CHECK: br ^[[COND_ENTRY]](%5 : tensor<i64>)
|
||||
// CHECK: ^[[EXIT]](%6: tensor<i64>):
|
||||
// CHECK: return %6 : tensor<i64>
|
||||
// CHECK: }
|
||||
%0 = "xla_hlo.while"(%arg0) ( {
|
||||
%0 = "mhlo.while"(%arg0) ( {
|
||||
^cond_entry(%arg1: tensor<i64>):
|
||||
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
|
||||
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"mhlo.return"(%1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^body_entry(%arg1: tensor<i64>):
|
||||
br ^body_succ(%arg1: tensor<i64>)
|
||||
^body_succ(%0: tensor<i64>):
|
||||
%1 = xla_hlo.add %0, %0 : tensor<i64>
|
||||
"xla_hlo.return"(%1) : (tensor<i64>) -> ()
|
||||
%1 = mhlo.add %0, %0 : tensor<i64>
|
||||
"mhlo.return"(%1) : (tensor<i64>) -> ()
|
||||
}) : (tensor<i64>) -> tensor<i64>
|
||||
|
||||
return %0 : tensor<i64>
|
||||
|
@ -94,7 +94,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
|
|||
// CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
|
||||
// CHECK: br ^[[COND_SUCC:.+]](%0 : tensor<i64>)
|
||||
// CHECK: ^[[COND_SUCC]](%1: tensor<i64>):
|
||||
// CHECK: %2 = "xla_hlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK: %3 = extract_element %2[] : tensor<i1>
|
||||
// CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
|
||||
// CHECK: ^[[BODY_ENTRY]](%4: tensor<i64>):
|
||||
|
@ -102,15 +102,15 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
|
|||
// CHECK: ^[[EXIT]](%5: tensor<i64>):
|
||||
// CHECK: return %5 : tensor<i64>
|
||||
// CHECK: }
|
||||
%0 = "xla_hlo.while"(%arg0) ( {
|
||||
%0 = "mhlo.while"(%arg0) ( {
|
||||
^cond_entry(%arg1: tensor<i64>):
|
||||
br ^cond_succ(%arg1: tensor<i64>)
|
||||
^cond_succ(%0: tensor<i64>):
|
||||
%1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
|
||||
%1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"mhlo.return"(%1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^body_entry(%arg1: tensor<i64>):
|
||||
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
|
||||
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
|
||||
}) : (tensor<i64>) -> tensor<i64>
|
||||
|
||||
return %0 : tensor<i64>
|
||||
|
@ -123,24 +123,24 @@ func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %
|
|||
// CHECK: ^[[THEN_ENTRY]](%1: tensor<f32>):
|
||||
// CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor<f32>)
|
||||
// CHECK: ^[[THEN_SUCC]](%2: tensor<f32>):
|
||||
// CHECK: %3 = "xla_hlo.log"(%2) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %3 = "mhlo.log"(%2) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: br ^[[EXIT:.+]](%3 : tensor<f32>)
|
||||
// CHECK: ^[[ELSE_ENTRY]](%4: tensor<f32>):
|
||||
// CHECK: %5 = "xla_hlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %5 = "mhlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: br ^[[EXIT]](%5 : tensor<f32>)
|
||||
// CHECK: ^[[EXIT]](%6: tensor<f32>):
|
||||
// CHECK: return %6 : tensor<f32>
|
||||
// CHECK: }
|
||||
%1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( {
|
||||
%1 = "mhlo.if"(%pred, %arg0, %arg1) ( {
|
||||
^then_entry(%arg2: tensor<f32>):
|
||||
br ^then_succ(%arg2: tensor<f32>)
|
||||
^then_succ(%0: tensor<f32>):
|
||||
%2 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>
|
||||
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
|
||||
%2 = "mhlo.log"(%0) : (tensor<f32>) -> tensor<f32>
|
||||
"mhlo.return"(%2) : (tensor<f32>) -> ()
|
||||
}, {
|
||||
^else_entry(%arg2: tensor<f32>):
|
||||
%2 = "xla_hlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
|
||||
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
|
||||
%2 = "mhlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
|
||||
"mhlo.return"(%2) : (tensor<f32>) -> ()
|
||||
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
|
|
@ -3,19 +3,19 @@
|
|||
// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-NEXT: %0 = addf %arg0, %arg1 : tensor<4xf32>
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32>
|
||||
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32>
|
||||
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32>
|
||||
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK-NEXT: %4 = remf %3, %arg1 : tensor<4xf32>
|
||||
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK-NEXT: return %4 : tensor<4xf32>
|
||||
return %4 : tensor<4xf32>
|
||||
|
@ -24,19 +24,19 @@ func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf
|
|||
// CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// CHECK-NEXT: %0 = addi %arg0, %arg1 : tensor<4xi32>
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32>
|
||||
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32>
|
||||
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32>
|
||||
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32>
|
||||
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK-NEXT: return %4 : tensor<4xi32>
|
||||
return %4 : tensor<4xi32>
|
||||
|
@ -45,17 +45,17 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32
|
|||
// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
|
||||
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
|
||||
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32>
|
||||
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32>
|
||||
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32>
|
||||
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32>
|
||||
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32>
|
||||
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
||||
return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
||||
}
|
||||
|
@ -63,28 +63,28 @@ func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi
|
|||
// CHECK-LABEL: func @compare_float
|
||||
func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
|
||||
// CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32>
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32>
|
||||
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32>
|
||||
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32>
|
||||
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32>
|
||||
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32>
|
||||
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @int_constant
|
||||
func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
|
||||
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<i32>
|
||||
%0 = "xla_hlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
|
||||
%0 = "mhlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
|
||||
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32>
|
||||
%1 = "xla_hlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
|
||||
%1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
|
||||
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32>
|
||||
%2 = "xla_hlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
|
||||
%2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
|
||||
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
|
||||
return %0, %1, %2: tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
|
||||
}
|
||||
|
@ -92,11 +92,11 @@ func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
|
|||
// CHECK-LABEL: func @float_constant
|
||||
func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
|
||||
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<f32>
|
||||
%0 = "xla_hlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
|
||||
%0 = "mhlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
|
||||
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32>
|
||||
%1 = "xla_hlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
|
||||
%1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
|
||||
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32>
|
||||
%2 = "xla_hlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
|
||||
%2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
|
||||
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
|
||||
return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
|
||||
}
|
||||
|
@ -105,7 +105,7 @@ func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
|
|||
// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
|
||||
func @iota.const.1() -> tensor<4xi32> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
|
||||
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
|
@ -113,7 +113,7 @@ func @iota.const.1() -> tensor<4xi32> {
|
|||
// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
|
||||
func @iota.const.2() -> tensor<2x4xi32> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
|
||||
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
|
||||
return %0 : tensor<2x4xi32>
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ func @iota.const.2() -> tensor<2x4xi32> {
|
|||
// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
|
||||
func @iota.const.3() -> tensor<2x4xi32> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
|
||||
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
|
||||
return %0 : tensor<2x4xi32>
|
||||
}
|
||||
|
@ -129,7 +129,7 @@ func @iota.const.3() -> tensor<2x4xi32> {
|
|||
// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
|
||||
func @iota.const.4() -> tensor<2x3x4xi32> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
|
||||
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
|
||||
return %0 : tensor<2x3x4xi32>
|
||||
}
|
||||
|
@ -137,7 +137,7 @@ func @iota.const.4() -> tensor<2x3x4xi32> {
|
|||
// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
|
||||
func @iota.const.5() -> tensor<2x3x4xi32> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
|
||||
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
|
||||
return %0 : tensor<2x3x4xi32>
|
||||
}
|
||||
|
@ -145,7 +145,7 @@ func @iota.const.5() -> tensor<2x3x4xi32> {
|
|||
// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
|
||||
func @iota.const.6() -> tensor<2x3x4xi32> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
|
||||
%0 = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
|
||||
return %0 : tensor<2x3x4xi32>
|
||||
}
|
||||
|
@ -153,7 +153,7 @@ func @iota.const.6() -> tensor<2x3x4xi32> {
|
|||
// CHECK-LABEL: func @iota.const.f32
|
||||
func @iota.const.f32() -> tensor<4xf32> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
|
||||
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
@ -161,7 +161,7 @@ func @iota.const.f32() -> tensor<4xf32> {
|
|||
// CHECK-LABEL: func @iota.const.f64
|
||||
func @iota.const.f64() -> tensor<4xf64> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
|
||||
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<4xf64>
|
||||
return %0 : tensor<4xf64>
|
||||
}
|
||||
|
@ -169,7 +169,7 @@ func @iota.const.f64() -> tensor<4xf64> {
|
|||
// CHECK-LABEL: func @iota.const.bf16
|
||||
func @iota.const.bf16() -> tensor<4xbf16> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
|
||||
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<4xbf16>
|
||||
return %0 : tensor<4xbf16>
|
||||
}
|
||||
|
@ -178,8 +178,8 @@ func @iota.const.bf16() -> tensor<4xbf16> {
|
|||
func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
|
||||
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
|
||||
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32>
|
||||
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
|
||||
// CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
|
||||
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
|
||||
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f32>>
|
||||
return %0 : tensor<4xcomplex<f32>>
|
||||
}
|
||||
|
@ -188,8 +188,8 @@ func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
|
|||
func @iota.const.complex.f64() -> tensor<4xcomplex<f64>> {
|
||||
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
|
||||
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64>
|
||||
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
|
||||
// CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
|
||||
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
|
||||
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f64>>
|
||||
return %0 : tensor<4xcomplex<f64>>
|
||||
}
|
||||
|
|
|
@ -396,9 +396,9 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
|
|||
"xla_lhlo.fusion"() ( {
|
||||
%0 = tensor_load %input1 : memref<10xf32>
|
||||
%1 = tensor_load %input2 : memref<10xf32>
|
||||
%2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||
%3 = tensor_load %input3 : memref<10xf32>
|
||||
%4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||
tensor_store %4, %out : memref<10xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
} ) : () -> ()
|
||||
|
@ -803,15 +803,15 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
|
|||
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
|
||||
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
|
||||
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
|
||||
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
|
||||
"mhlo.return"(%max) : (tensor<f32>) -> ()
|
||||
})
|
||||
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
|
||||
|
||||
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
|
||||
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
|
||||
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
|
||||
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
|
||||
"mhlo.return"(%max) : (tensor<f32>) -> ()
|
||||
})
|
||||
{
|
||||
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
|
||||
|
@ -958,8 +958,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
|
|||
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
|
||||
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
|
||||
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
|
||||
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
|
||||
%add = mhlo.add %lhs, %rhs : tensor<f32>
|
||||
"mhlo.return"(%add) : (tensor<f32>) -> ()
|
||||
}) {
|
||||
scatter_dimension_numbers = {
|
||||
update_window_dims = dense<[1]> : tensor<1xi64>,
|
||||
|
@ -979,8 +979,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
|
|||
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
|
||||
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>):
|
||||
%c = xla_hlo.add %a, %b : tensor<f32>
|
||||
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
|
||||
%c = mhlo.add %a, %b : tensor<f32>
|
||||
"mhlo.return"(%c) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -991,8 +991,8 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
|
|||
// expected-error@+1{{requires the same shape for all operands}}
|
||||
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>):
|
||||
%c = xla_hlo.add %a, %b : tensor<f32>
|
||||
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
|
||||
%c = mhlo.add %a, %b : tensor<f32>
|
||||
"mhlo.return"(%c) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -1012,8 +1012,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
|||
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -1025,8 +1025,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
|||
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -1038,8 +1038,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
|||
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -2,14 +2,14 @@
|
|||
|
||||
// CHECK-LABEL: @add
|
||||
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
|
||||
%4 = "xla_hlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
|
||||
%4 = "mhlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL0]], [[VAL1]]
|
||||
return %5, %6 : tensor<2xf32>, tensor<2xf32>
|
||||
|
@ -17,14 +17,14 @@ func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
|
|||
|
||||
// CHECK-LABEL: @add_unranked
|
||||
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
|
||||
%4 = "xla_hlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
|
||||
%4 = "mhlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
|
||||
// CHECK: return [[VAL0]], [[VAL1]]
|
||||
return %5, %6 : tensor<*xf32>, tensor<*xf32>
|
||||
|
@ -32,14 +32,14 @@ func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
|
|||
|
||||
// CHECK-LABEL: @sub
|
||||
func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3
|
||||
%4 = "xla_hlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
|
||||
%4 = "mhlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL0]], [[VAL1]]
|
||||
return %5, %6 : tensor<2xf32>, tensor<2xf32>
|
||||
|
@ -47,14 +47,14 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
|
|||
|
||||
// CHECK-LABEL: @sub_unranked
|
||||
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3
|
||||
%4 = "xla_hlo.subtract"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
|
||||
%4 = "mhlo.subtract"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
|
||||
// CHECK: return [[VAL0]], [[VAL1]]
|
||||
return %5, %6 : tensor<*xf32>, tensor<*xf32>
|
||||
|
@ -62,18 +62,18 @@ func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
|
|||
|
||||
// CHECK-LABEL: @mul
|
||||
func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3
|
||||
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
|
||||
%4 = "xla_hlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
|
||||
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
|
||||
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
|
||||
%4 = "mhlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32>
|
||||
return %5, %6 : tensor<2xf32>, tensor<2xf32>
|
||||
|
@ -81,18 +81,18 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
|
|||
|
||||
// CHECK-LABEL: @mul_unranked
|
||||
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3
|
||||
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
|
||||
%4 = "xla_hlo.multiply"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
|
||||
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
|
||||
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
|
||||
%4 = "mhlo.multiply"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
|
||||
// CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32>
|
||||
return %5, %6 : tensor<*xf32>, tensor<*xf32>
|
||||
|
@ -100,36 +100,36 @@ func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
|
|||
|
||||
// CHECK-LABEL: @div
|
||||
func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)
|
||||
|
||||
// Compute the numerator's real component:
|
||||
// numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
|
||||
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]
|
||||
|
||||
// Compute the real valued denominator as rhs * con(rhs):
|
||||
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
|
||||
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
|
||||
|
||||
// Compute the numerator's imaginary component:
|
||||
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
|
||||
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
|
||||
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]
|
||||
|
||||
// Divide the numerator by the real valued denominator.
|
||||
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]]
|
||||
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]]
|
||||
%4 = "xla_hlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
|
||||
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
|
||||
%4 = "mhlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL10]], [[VAL11]]
|
||||
return %5, %6 : tensor<2xf32>, tensor<2xf32>
|
||||
|
@ -139,36 +139,36 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
|
|||
|
||||
// CHECK-LABEL: @div_unranked
|
||||
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)
|
||||
|
||||
// Compute the numerator's real component:
|
||||
// numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
|
||||
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]
|
||||
|
||||
// Compute the real valued denominator as rhs * con(rhs):
|
||||
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
|
||||
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
|
||||
|
||||
// Compute the numerator's imaginary component:
|
||||
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
|
||||
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
|
||||
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]
|
||||
|
||||
// Divide the numerator by the real valued denominator.
|
||||
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]]
|
||||
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]]
|
||||
%4 = "xla_hlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
|
||||
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
|
||||
%4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
|
||||
// CHECK: return [[VAL10]], [[VAL11]]
|
||||
return %5, %6 : tensor<*xf32>, tensor<*xf32>
|
||||
|
@ -176,14 +176,14 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
|
|||
|
||||
// CHECK-LABEL: @abs
|
||||
func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
|
||||
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg0
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg1
|
||||
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]])
|
||||
%1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg0
|
||||
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1
|
||||
// CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]])
|
||||
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL3]]
|
||||
return %2 : tensor<2xf32>
|
||||
|
@ -191,16 +191,16 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
|
|||
|
||||
// CHECK-LABEL: @exp
|
||||
func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0)
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1)
|
||||
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1)
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]]
|
||||
%1 = "xla_hlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
|
||||
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
|
||||
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
|
||||
%1 = "mhlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%3 = "mhlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL3]], [[VAL4]]
|
||||
return %2, %3 : tensor<2xf32>, tensor<2xf32>
|
||||
|
@ -208,16 +208,16 @@ func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tenso
|
|||
|
||||
// CHECK-LABEL: @exp_unranked
|
||||
func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0)
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1)
|
||||
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1)
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]]
|
||||
%1 = "xla_hlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%2 = "xla_hlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
|
||||
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
|
||||
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
|
||||
%1 = "mhlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%2 = "mhlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%3 = "mhlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
|
||||
// CHECK: return [[VAL3]], [[VAL4]]
|
||||
return %2, %3 : tensor<*xf32>, tensor<*xf32>
|
||||
|
|
|
@ -2,10 +2,10 @@
|
|||
|
||||
// CHECK-LABEL: @testDebatch1
|
||||
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
|
||||
// CHECK-DAG: [[R0:%.+]] = "xla_hlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
|
||||
// CHECK-DAG: [[R1:%.+]] = "xla_hlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
|
||||
// CHECK: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
|
||||
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>
|
||||
// CHECK-DAG: [[R0:%.+]] = "mhlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
|
||||
// CHECK-DAG: [[R1:%.+]] = "mhlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
|
||||
// CHECK: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
|
||||
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>
|
||||
|
||||
return %0 : tensor<1x1x3xf32>
|
||||
}
|
||||
|
@ -14,13 +14,13 @@ func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1
|
|||
|
||||
// CHECK-LABEL: @testDebatch2
|
||||
func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> {
|
||||
// CHECK-DAG: [[R0:%.+]] = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
|
||||
// CHECK-DAG: [[R1:%.+]] = "xla_hlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
|
||||
// CHECK-DAG: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
|
||||
// CHECK-DAG: [[R3:%.+]] = "xla_hlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
|
||||
// CHECK: [[R4:%.+]] = "xla_hlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32>
|
||||
// CHECK-DAG: [[R0:%.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
|
||||
// CHECK-DAG: [[R1:%.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
|
||||
// CHECK-DAG: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
|
||||
// CHECK-DAG: [[R3:%.+]] = "mhlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
|
||||
// CHECK: [[R4:%.+]] = "mhlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32>
|
||||
|
||||
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
|
||||
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
|
||||
return %0 : tensor<3x1x1xf32>
|
||||
}
|
||||
|
||||
|
@ -28,8 +28,8 @@ func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3
|
|||
|
||||
// CHECK-LABEL: @testBatchPassthrough
|
||||
func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> {
|
||||
// CHECK-NEXT: "xla_hlo.dot_general"(%arg0, %arg1)
|
||||
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK-NEXT: "mhlo.dot_general"(%arg0, %arg1)
|
||||
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32>
|
||||
return %0 : tensor<3x2x1xf32>
|
||||
}
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@
|
|||
// CHECK-LABEL: @clampBroadcast
|
||||
// CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)
|
||||
func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>) -> tensor<4xf32> {
|
||||
// CHECK-DAG: %[[MIN_BC:.+]] = "xla_hlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK-DAG: %[[MAX_BC:.+]] = "xla_hlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK: "xla_hlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK-DAG: %[[MIN_BC:.+]] = "mhlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK-DAG: %[[MAX_BC:.+]] = "mhlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK: "mhlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%0 = "mhlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
|
404
tests/ops.mlir
404
tests/ops.mlir
File diff suppressed because it is too large
Load Diff
|
@ -4,11 +4,11 @@
|
|||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
|
||||
// CHECK: return %[[ARG0]]
|
||||
func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
|
||||
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%2 = "xla_hlo.reduce"(%arg0, %0) ( {
|
||||
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%2 = "mhlo.reduce"(%arg0, %0) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%4 = xla_hlo.add %arg1, %arg2 : tensor<f32>
|
||||
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
|
||||
%4 = mhlo.add %arg1, %arg2 : tensor<f32>
|
||||
"mhlo.return"(%4) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
|
||||
return %2 : tensor<4x8xf32>
|
||||
}
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
// CHECK-LABEL: func @const_fold_collapse_to_scalar
|
||||
func @const_fold_collapse_to_scalar() -> tensor<i32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
|
||||
%cst = xla_hlo.constant dense<42> : tensor<1x1xi32>
|
||||
%0 = "xla_hlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
|
||||
%cst = mhlo.constant dense<42> : tensor<1x1xi32>
|
||||
%0 = "mhlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
@ -13,9 +13,9 @@ func @const_fold_collapse_to_scalar() -> tensor<i32> {
|
|||
|
||||
// CHECK-LABEL: func @const_fold_collapse_to_tensor
|
||||
func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<2xi32>
|
||||
%cst = xla_hlo.constant dense<42> : tensor<1x2xi32>
|
||||
%0 = "xla_hlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<2xi32>
|
||||
%cst = mhlo.constant dense<42> : tensor<1x2xi32>
|
||||
%0 = "mhlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
@ -24,9 +24,9 @@ func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
|
|||
|
||||
// CHECK-LABEL: func @const_fold_expand
|
||||
func @const_fold_expand() -> tensor<1xi32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<1xi32>
|
||||
%cst = xla_hlo.constant dense<42> : tensor<i32>
|
||||
%0 = "xla_hlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<1xi32>
|
||||
%cst = mhlo.constant dense<42> : tensor<i32>
|
||||
%0 = "mhlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<1xi32>
|
||||
}
|
||||
|
@ -35,9 +35,9 @@ func @const_fold_expand() -> tensor<1xi32> {
|
|||
|
||||
// CHECK-LABEL: func @const_fold_nontrivial
|
||||
func @const_fold_nontrivial() -> tensor<16xi64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64>
|
||||
%cst = xla_hlo.constant dense<42> : tensor<4x4xi64>
|
||||
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
|
||||
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
|
||||
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<16xi64>
|
||||
}
|
||||
|
@ -46,9 +46,9 @@ func @const_fold_nontrivial() -> tensor<16xi64> {
|
|||
|
||||
// CHECK-LABEL: func @const_fold_flatten
|
||||
func @const_fold_flatten() -> tensor<16xi64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64>
|
||||
%cst = xla_hlo.constant dense<42> : tensor<4x4xi64>
|
||||
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
|
||||
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
|
||||
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<16xi64>
|
||||
}
|
||||
|
@ -57,9 +57,9 @@ func @const_fold_flatten() -> tensor<16xi64> {
|
|||
|
||||
// CHECK-LABEL: func @const_fold_6
|
||||
func @const_fold_6() -> tensor<6xi32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
|
||||
%cst = xla_hlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
|
||||
%0 = "xla_hlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
|
||||
%cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
|
||||
%0 = "mhlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<6xi32>
|
||||
}
|
||||
|
@ -68,11 +68,11 @@ func @const_fold_6() -> tensor<6xi32> {
|
|||
|
||||
// CHECK-LABEL: func @const_fold_same_shape
|
||||
func @const_fold_same_shape() -> tensor<2x3xi32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
|
||||
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
|
||||
// CHECK-SAME: ]> : tensor<2x3xi32>
|
||||
%cst = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
|
||||
%0 = "xla_hlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
|
||||
%cst = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
|
||||
%0 = "mhlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<2x3xi32>
|
||||
}
|
||||
|
@ -81,9 +81,9 @@ func @const_fold_same_shape() -> tensor<2x3xi32> {
|
|||
|
||||
// CHECK-LABEL: func @const_fold_float
|
||||
func @const_fold_float() -> tensor<16xf64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
|
||||
%cst = xla_hlo.constant dense<4.2> : tensor<4x4xf64>
|
||||
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
|
||||
%cst = mhlo.constant dense<4.2> : tensor<4x4xf64>
|
||||
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<16xf64>
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ func @const_fold_float() -> tensor<16xf64> {
|
|||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-NEXT: return [[ARG]]
|
||||
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
return %0 : tensor<2x3xi32>
|
||||
}
|
||||
|
||||
|
@ -103,10 +103,10 @@ func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
|
|||
// CHECK-LABEL: func @non_const_chained_reshape
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) {
|
||||
// CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
// CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
|
||||
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
|
||||
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
|
||||
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
|
||||
return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed
|
||||
}
|
||||
|
||||
|
@ -115,9 +115,9 @@ func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, ten
|
|||
// CHECK-LABEL: func @non_const_chained_reshape_unused_parent
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> {
|
||||
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
|
||||
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
|
||||
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
|
||||
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
|
||||
// CHECK-NEXT: return [[RES]]
|
||||
return %1 : tensor<6xi32>
|
||||
}
|
||||
|
@ -127,8 +127,8 @@ func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<
|
|||
// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
|
||||
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
|
||||
// CHECK-NEXT: return [[ARG]]
|
||||
return %1 : tensor<2x3xi32>
|
||||
}
|
||||
|
@ -138,12 +138,12 @@ func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2
|
|||
// CHECK-LABEL: func @non_const_many_chained_reshapes
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> {
|
||||
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
|
||||
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
|
||||
%1 = "xla_hlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
|
||||
%2 = "xla_hlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
|
||||
%3 = "xla_hlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
|
||||
%4 = "xla_hlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
|
||||
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
|
||||
%0 = "mhlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
|
||||
%2 = "mhlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
|
||||
%3 = "mhlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
|
||||
%4 = "mhlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
|
||||
// CHECK-NEXT: return [[RES]]
|
||||
return %4 : tensor<1x2x4x3xi32>
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
// CHECK-LABEL: func @noop
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>)
|
||||
func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
|
||||
%0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
|
||||
%0 = "mhlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
|
||||
// CHECK: return %[[ARG0]]
|
||||
return %0 : tensor<1x2xf32>
|
||||
}
|
||||
|
|
|
@ -4,27 +4,27 @@
|
|||
|
||||
// CHECK-LABEL: func @sink_const_to_while
|
||||
func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
// CHECK-NEXT: xla_hlo.while
|
||||
%c0 = xla_hlo.constant dense<1> : tensor<i64>
|
||||
%c1 = xla_hlo.constant dense<2> : tensor<i64>
|
||||
%0 = "xla_hlo.while"(%arg0) ( {
|
||||
// CHECK-NEXT: mhlo.while
|
||||
%c0 = mhlo.constant dense<1> : tensor<i64>
|
||||
%c1 = mhlo.constant dense<2> : tensor<i64>
|
||||
%0 = "mhlo.while"(%arg0) ( {
|
||||
^bb0(%arg1: tensor<i64>):
|
||||
// CHECK: %[[ARG1A:.+]]: tensor<i64>
|
||||
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
|
||||
// CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]])
|
||||
%1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
|
||||
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
|
||||
// CHECK: "mhlo.compare"(%[[C0]], %[[ARG1A]])
|
||||
%1 = "mhlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
"mhlo.return"(%1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg1: tensor<i64>):
|
||||
// CHECK: %[[ARG1B:.+]]: tensor<i64>
|
||||
// CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
|
||||
// CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]]
|
||||
%2 = xla_hlo.add %arg1, %arg1 : tensor<i64>
|
||||
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]]
|
||||
%3 = xla_hlo.add %c1, %2 : tensor<i64>
|
||||
// CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]]
|
||||
%4 = xla_hlo.add %c1, %3 : tensor<i64>
|
||||
"xla_hlo.return"(%4) : (tensor<i64>) -> ()
|
||||
// CHECK-DAG: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
|
||||
// CHECK-DAG: %[[ADD0:.+]] = mhlo.add %[[ARG1B]], %[[ARG1B]]
|
||||
%2 = mhlo.add %arg1, %arg1 : tensor<i64>
|
||||
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]], %[[ADD0]]
|
||||
%3 = mhlo.add %c1, %2 : tensor<i64>
|
||||
// CHECK: %[[ADD2:.+]] = mhlo.add %[[C1]], %[[ADD1]]
|
||||
%4 = mhlo.add %c1, %3 : tensor<i64>
|
||||
"mhlo.return"(%4) : (tensor<i64>) -> ()
|
||||
}) : (tensor<i64>) -> tensor<i64>
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
|
@ -33,28 +33,28 @@ func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
|
|||
|
||||
// CHECK-LABEL: func @sink_const_to_conditional
|
||||
func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
%c0 = xla_hlo.constant dense<1> : tensor<i64>
|
||||
%c1 = xla_hlo.constant dense<2> : tensor<i64>
|
||||
%0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
%1 = "xla_hlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
|
||||
// CHECK: xla_hlo.if
|
||||
%2 = "xla_hlo.if"(%0, %1, %1) ( {
|
||||
%c0 = mhlo.constant dense<1> : tensor<i64>
|
||||
%c1 = mhlo.constant dense<2> : tensor<i64>
|
||||
%0 = "mhlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
%1 = "mhlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
|
||||
// CHECK: mhlo.if
|
||||
%2 = "mhlo.if"(%0, %1, %1) ( {
|
||||
^bb0(%arg1: tuple<tensor<i64>>):
|
||||
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
|
||||
%3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
|
||||
// CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]],
|
||||
%4 = xla_hlo.add %c0, %3 : tensor<i64>
|
||||
%5 = "xla_hlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
|
||||
"xla_hlo.return"(%5) : (tuple<tensor<i64>>) -> ()
|
||||
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
|
||||
%3 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
|
||||
// CHECK: %[[ADD0:.+]] = mhlo.add %[[C0]],
|
||||
%4 = mhlo.add %c0, %3 : tensor<i64>
|
||||
%5 = "mhlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
|
||||
"mhlo.return"(%5) : (tuple<tensor<i64>>) -> ()
|
||||
}, {
|
||||
^bb0(%arg1: tuple<tensor<i64>>):
|
||||
// CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
|
||||
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
|
||||
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]],
|
||||
%7 = xla_hlo.add %c1, %6 : tensor<i64>
|
||||
%8 = "xla_hlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
|
||||
"xla_hlo.return"(%8) : (tuple<tensor<i64>>) -> ()
|
||||
// CHECK: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
|
||||
%6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
|
||||
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]],
|
||||
%7 = mhlo.add %c1, %6 : tensor<i64>
|
||||
%8 = "mhlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
|
||||
"mhlo.return"(%8) : (tuple<tensor<i64>>) -> ()
|
||||
}) : (tensor<i1>, tuple<tensor<i64>>, tuple<tensor<i64>>) -> tuple<tensor<i64>>
|
||||
%9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
|
||||
%9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
|
||||
return %9 : tensor<i64>
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
// CHECK-LABEL: func @remove_noop
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
|
||||
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
|
||||
%0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
|
||||
// CHECK-NEXT: return [[ARG]]
|
||||
return %0 : tensor<2x3x9x5xi32>
|
||||
}
|
||||
|
@ -13,8 +13,8 @@ func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
|
|||
// CHECK-LABEL: func @keep_real_transpose
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
|
||||
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
|
||||
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
|
||||
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
|
||||
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
|
||||
return %0 : tensor<3x2x5x9xi32>
|
||||
}
|
||||
|
||||
|
@ -23,7 +23,7 @@ func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
|
|||
// CHECK-LABEL: func @keep_same_shape_real_transpose
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> {
|
||||
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
|
||||
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
|
||||
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
|
||||
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
|
||||
return %0 : tensor<4x4xi32>
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @fold_access(%arg : tensor<i32>) -> tensor<i32> {
|
||||
// CHECK-NEXT: return [[ARG]]
|
||||
%tuple = "xla_hlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
|
||||
%element = "xla_hlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
|
||||
%tuple = "mhlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
|
||||
%element = "mhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
|
||||
return %element : tensor<i32>
|
||||
}
|
||||
|
|
|
@ -10,19 +10,19 @@ func @batchNormInference_2D_inner_features(
|
|||
%x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
|
||||
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
|
||||
-> (tensor<4x256xf32>) {
|
||||
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor<f32>
|
||||
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
|
||||
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
|
||||
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
|
||||
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
|
||||
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor<f32>
|
||||
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
|
||||
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
|
||||
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
|
||||
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
|
||||
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
|
||||
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
{epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} :
|
||||
(tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
|
||||
tensor<256xf32>) -> tensor<4x256xf32>
|
||||
|
@ -36,12 +36,12 @@ func @batchNormInference_2D_inner_features(
|
|||
// the verifier to enforce the rest.
|
||||
// CHECK-SAME: %[[X:[^:]+]]
|
||||
// CHECK-SAME: %[[SCALE:[^:]+]]
|
||||
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
|
||||
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
|
||||
func @batchNormInference_4D_middle_features(
|
||||
%x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
|
||||
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
|
||||
-> (tensor<3x4x256x6xf32>) {
|
||||
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
{epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} :
|
||||
(tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
|
||||
tensor<256xf32>) -> tensor<3x4x256x6xf32>
|
||||
|
@ -51,12 +51,12 @@ func @batchNormInference_4D_middle_features(
|
|||
// -----
|
||||
// CHECK-LABEL: @batchNormInference_f64
|
||||
// Validate that epsilon is properly promoted to f64
|
||||
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f64>
|
||||
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
|
||||
func @batchNormInference_f64(
|
||||
%x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>,
|
||||
%mean: tensor<256xf64>, %variance: tensor<256xf64>)
|
||||
-> (tensor<4x256xf64>) {
|
||||
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
|
||||
(tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>,
|
||||
tensor<256xf64>) -> tensor<4x256xf64>
|
||||
|
@ -66,12 +66,12 @@ func @batchNormInference_f64(
|
|||
// -----
|
||||
// CHECK-LABEL: @batchNormInference_f16
|
||||
// Validate that epsilon is properly promoted to f64
|
||||
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f16>
|
||||
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f16>
|
||||
func @batchNormInference_f16(
|
||||
%x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
|
||||
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
|
||||
-> (tensor<4x256xf16>) {
|
||||
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
|
||||
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
|
||||
tensor<256xf16>) -> tensor<4x256xf16>
|
||||
|
@ -85,7 +85,7 @@ func @batchNormInference_f16_overflow(
|
|||
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
|
||||
-> (tensor<4x256xf16>) {
|
||||
// expected-warning @+1 {{Could not convert batch_norm epsilon to target fp type: opStatus = 24}}
|
||||
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
{epsilon = 0.00000001 : f32, feature_index = 1 : i64} :
|
||||
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
|
||||
tensor<256xf16>) -> tensor<4x256xf16>
|
||||
|
@ -108,26 +108,26 @@ func @batchNormInference_dynamic_shape(
|
|||
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
|
||||
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
|
||||
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
|
||||
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
|
||||
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex>
|
||||
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
|
||||
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
|
||||
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex>
|
||||
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
|
||||
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
|
||||
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
|
||||
{epsilon = 0.001 : f32, feature_index = 1 : i64} :
|
||||
(tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
|
||||
tensor<?xf32>) -> tensor<?x?x?x?xf32>
|
||||
|
|
|
@ -2,14 +2,14 @@
|
|||
|
||||
// CHECK-LABEL: func @multi_outputs_same
|
||||
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%2 = "xla_hlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RET:.*]]:2 = "xla_hlo.fusion"
|
||||
// CHECK-NEXT: xla_hlo.add
|
||||
// CHECK-NEXT: xla_hlo.subtract
|
||||
// CHECK-NEXT: xla_hlo.add
|
||||
// CHECK-NEXT: xla_hlo.return
|
||||
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%2 = "mhlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RET:.*]]:2 = "mhlo.fusion"
|
||||
// CHECK-NEXT: mhlo.add
|
||||
// CHECK-NEXT: mhlo.subtract
|
||||
// CHECK-NEXT: mhlo.add
|
||||
// CHECK-NEXT: mhlo.return
|
||||
return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
|
||||
}
|
||||
|
||||
|
@ -17,18 +17,18 @@ func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (ten
|
|||
|
||||
// CHECK-LABEL: func @multi_outputs_same_2
|
||||
func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||
%0 = "xla_hlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "xla_hlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%2 = "xla_hlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%3 = "xla_hlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%4 = "xla_hlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RET:.*]]:3 = "xla_hlo.fusion"
|
||||
// CHECK-NEXT: xla_hlo.abs
|
||||
// CHECK-NEXT: xla_hlo.abs
|
||||
// CHECK-NEXT: xla_hlo.add
|
||||
// CHECK-NEXT: xla_hlo.abs
|
||||
// CHECK-NEXT: xla_hlo.abs
|
||||
// CHECK-NEXT: xla_hlo.return
|
||||
%0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "mhlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%2 = "mhlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%3 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%4 = "mhlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RET:.*]]:3 = "mhlo.fusion"
|
||||
// CHECK-NEXT: mhlo.abs
|
||||
// CHECK-NEXT: mhlo.abs
|
||||
// CHECK-NEXT: mhlo.add
|
||||
// CHECK-NEXT: mhlo.abs
|
||||
// CHECK-NEXT: mhlo.abs
|
||||
// CHECK-NEXT: mhlo.return
|
||||
return %2, %3, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
|
||||
}
|
||||
|
||||
|
@ -36,9 +36,9 @@ func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (t
|
|||
|
||||
// CHECK-LABEL: func @multi_outputs_not_sure_same
|
||||
func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK-NOT: xla_hlo.fusion
|
||||
%1 = "xla_hlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%0 = "mhlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK-NOT: mhlo.fusion
|
||||
%1 = "mhlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
|
||||
}
|
||||
|
||||
|
@ -46,25 +46,25 @@ func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>
|
|||
|
||||
// CHECK-LABEL: func @reduce
|
||||
func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RET0:.*]] = "xla_hlo.fusion"
|
||||
// CHECK-NEXT: xla_hlo.add
|
||||
// CHECK-NEXT: xla_hlo.subtract
|
||||
// CHECK-NEXT: xla_hlo.return
|
||||
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RET0:.*]] = "mhlo.fusion"
|
||||
// CHECK-NEXT: mhlo.add
|
||||
// CHECK-NEXT: mhlo.subtract
|
||||
// CHECK-NEXT: mhlo.return
|
||||
// Currently we do not support fuse arguments and ops without direct producer-consumer
|
||||
// relationship. Thus Reduce Op should not be fused with above two ops.
|
||||
|
||||
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%3 = "xla_hlo.reduce"(%arg0, %2) ( {
|
||||
%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%3 = "mhlo.reduce"(%arg0, %2) ( {
|
||||
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
|
||||
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
|
||||
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
"mhlo.return"(%4) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// Above two ops should not be fused since reduce op can not be
|
||||
// fused with its consumer.
|
||||
// CHECK-NOT: xla_hlo.fusion
|
||||
// CHECK-NOT: mhlo.fusion
|
||||
|
||||
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
@ -73,25 +73,25 @@ func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>
|
|||
|
||||
// CHECK-LABEL: func @reduce_2
|
||||
func @reduce_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
||||
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%3 = "xla_hlo.reduce"(%1, %2) ( {
|
||||
%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%3 = "mhlo.reduce"(%1, %2) ( {
|
||||
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
|
||||
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
|
||||
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
"mhlo.return"(%4) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||
// CHECK: %[[RET0:.*]]:2 = "xla_hlo.fusion"
|
||||
// CHECK-NEXT: xla_hlo.add
|
||||
// CHECK-NEXT: xla_hlo.subtract
|
||||
// CHECK-NEXT: xla_hlo.constant
|
||||
// CHECK-NEXT: xla_hlo.reduce
|
||||
// CHECK: xla_hlo.return
|
||||
// CHECK: %[[RET0:.*]]:2 = "mhlo.fusion"
|
||||
// CHECK-NEXT: mhlo.add
|
||||
// CHECK-NEXT: mhlo.subtract
|
||||
// CHECK-NEXT: mhlo.constant
|
||||
// CHECK-NEXT: mhlo.reduce
|
||||
// CHECK: mhlo.return
|
||||
|
||||
// Following op should not be fused with the above ops since reduce op can not be
|
||||
// fused with its consumer.
|
||||
// CHECK-NOT: xla_hlo.fusion
|
||||
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NOT: mhlo.fusion
|
||||
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
|
|
@ -9,15 +9,15 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
|
|||
%num_elements = shape.num_elements %shape
|
||||
%num_elements_as_index = shape.size_to_index %num_elements
|
||||
%flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
|
||||
%flat_a = "xla_hlo.dynamic_reshape"(%a, %flat_shape)
|
||||
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
|
||||
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
|
||||
// Apply operation.
|
||||
%flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
|
||||
// Restore original shape.
|
||||
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
|
||||
%b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
|
||||
%b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
|
||||
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
|
||||
return %b : tensor<*xf32>
|
||||
|
@ -33,12 +33,12 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
||||
// CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
|
||||
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
|
||||
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
|
||||
%b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %b : tensor<*xf32>
|
||||
}
|
||||
|
||||
|
@ -48,9 +48,9 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK-LABEL: @sqrt_ranked
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
|
||||
func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
|
||||
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
||||
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
||||
// CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
|
||||
%b = "xla_hlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
||||
%b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
||||
return %b : tensor<3x?xf32>
|
||||
}
|
||||
|
||||
|
@ -60,9 +60,9 @@ func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
|
|||
// CHECK-LABEL: @sqrt_static
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
|
||||
func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
// CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
|
||||
%b = "xla_hlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %b : tensor<2x3xf32>
|
||||
}
|
||||
|
||||
|
@ -77,12 +77,12 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
||||
// CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
|
||||
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
|
||||
// CHECK: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[FLAT_B:.*]] = "xla_hlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[FLAT_RESULT:.*]] = xla_hlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
|
||||
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
|
||||
// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
|
||||
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<*xf32>
|
||||
%result = xla_hlo.add %a, %b : tensor<*xf32>
|
||||
%result = mhlo.add %a, %b : tensor<*xf32>
|
||||
return %result : tensor<*xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue