[MHLO:Linalg] Add support for lowering unsigned ops
This strips away the signedness with a type converter, using unrealized conversion casts. The rest is mostly mechanically pushing the original op down the pipeline so lowerings can see the original types. Signed types stay signless for now. This can be changed in the HLO bridge later. I did a pass over all ops and added unsigned lowerings where they were missing. There may be more. Currently the lowering will die at a later stage because it doesn't understand the unrealized casts. PiperOrigin-RevId: 371077494
This commit is contained in:
parent
09f6141562
commit
f4414fcd66
|
@ -44,41 +44,50 @@ template <>
|
||||||
struct LhloToScalarOp<lmhlo::AddOp> {
|
struct LhloToScalarOp<lmhlo::AddOp> {
|
||||||
using FOp = ::mlir::AddFOp;
|
using FOp = ::mlir::AddFOp;
|
||||||
using IOp = ::mlir::AddIOp;
|
using IOp = ::mlir::AddIOp;
|
||||||
|
using UOp = ::mlir::AddIOp;
|
||||||
using COp = ::mlir::complex::AddOp;
|
using COp = ::mlir::complex::AddOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct LhloToScalarOp<lmhlo::CompareOp> {
|
struct LhloToScalarOp<lmhlo::CompareOp> {
|
||||||
using FOp = ::mlir::CmpFOp;
|
using FOp = ::mlir::CmpFOp;
|
||||||
using IOp = ::mlir::CmpIOp;
|
using IOp = ::mlir::CmpIOp;
|
||||||
|
using UOp = ::mlir::CmpIOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct LhloToScalarOp<lmhlo::DivOp> {
|
struct LhloToScalarOp<lmhlo::DivOp> {
|
||||||
using FOp = ::mlir::DivFOp;
|
using FOp = ::mlir::DivFOp;
|
||||||
using IOp = ::mlir::SignedDivIOp;
|
using IOp = ::mlir::SignedDivIOp;
|
||||||
|
using UOp = ::mlir::UnsignedDivIOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct LhloToScalarOp<lmhlo::MulOp> {
|
struct LhloToScalarOp<lmhlo::MulOp> {
|
||||||
using FOp = ::mlir::MulFOp;
|
using FOp = ::mlir::MulFOp;
|
||||||
using IOp = ::mlir::MulIOp;
|
using IOp = ::mlir::MulIOp;
|
||||||
|
using UOp = ::mlir::MulIOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct LhloToScalarOp<lmhlo::RemOp> {
|
struct LhloToScalarOp<lmhlo::RemOp> {
|
||||||
using FOp = ::mlir::RemFOp;
|
using FOp = ::mlir::RemFOp;
|
||||||
using IOp = ::mlir::SignedRemIOp;
|
using IOp = ::mlir::SignedRemIOp;
|
||||||
|
using UOp = ::mlir::UnsignedRemIOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct LhloToScalarOp<lmhlo::SubOp> {
|
struct LhloToScalarOp<lmhlo::SubOp> {
|
||||||
using FOp = ::mlir::SubFOp;
|
using FOp = ::mlir::SubFOp;
|
||||||
using IOp = ::mlir::SubIOp;
|
using IOp = ::mlir::SubIOp;
|
||||||
|
using UOp = ::mlir::SubIOp;
|
||||||
using COp = ::mlir::complex::SubOp;
|
using COp = ::mlir::complex::SubOp;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
||||||
template <typename LhloOp>
|
template <typename LhloOp>
|
||||||
using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp;
|
using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp;
|
||||||
// Alias for the map from LHLO binary op type to STD integer op type.
|
// Alias for the map from LHLO binary op type to STD signed integer op type.
|
||||||
template <typename LhloOp>
|
template <typename LhloOp>
|
||||||
using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp;
|
using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp;
|
||||||
|
// Alias for the map from LHLO binary op type to STD unsigned integer op type.
|
||||||
|
template <typename LhloOp>
|
||||||
|
using ScalarUOp = typename LhloToScalarOp<LhloOp>::UOp;
|
||||||
// Alias for the map from LHLO binary op type to STD complex op type.
|
// Alias for the map from LHLO binary op type to STD complex op type.
|
||||||
template <typename LhloOp>
|
template <typename LhloOp>
|
||||||
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
|
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
|
||||||
|
@ -86,7 +95,8 @@ using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
struct MapLhloOpToScalarOpImpl {
|
struct MapLhloOpToScalarOpImpl {
|
||||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Type> arg_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -94,7 +104,8 @@ struct MapLhloOpToScalarOpImpl {
|
||||||
template <typename StdScalarOp>
|
template <typename StdScalarOp>
|
||||||
struct MapLhloOpToScalarOpImpl<StdScalarOp> {
|
struct MapLhloOpToScalarOpImpl<StdScalarOp> {
|
||||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Type> arg_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
|
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -102,41 +113,69 @@ struct MapLhloOpToScalarOpImpl<StdScalarOp> {
|
||||||
template <typename SupportedType, typename StdScalarOp, typename... Args>
|
template <typename SupportedType, typename StdScalarOp, typename... Args>
|
||||||
struct MapLhloOpToScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
struct MapLhloOpToScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Type> arg_types, ArrayRef<Value> args,
|
||||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
OpBuilder* b) {
|
||||||
if (element_type.isa<SupportedType>()) {
|
Type element_type = getElementTypeOrSelf(arg_types.front());
|
||||||
|
if (SupportedType{}(element_type)) {
|
||||||
return b->template create<StdScalarOp>(loc, result_types, args,
|
return b->template create<StdScalarOp>(loc, result_types, args,
|
||||||
mlir::None);
|
mlir::None);
|
||||||
}
|
}
|
||||||
return MapLhloOpToScalarOpImpl<Args...>{}(loc, result_types, args, b);
|
return MapLhloOpToScalarOpImpl<Args...>{}(loc, result_types, arg_types,
|
||||||
|
args, b);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct isAnyIntegerType {
|
||||||
|
bool operator()(Type t) { return t.isa<IntegerType>(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct isSignedIntegerType {
|
||||||
|
bool operator()(Type t) {
|
||||||
|
// Pretend that signless is signed. This will change eventually.
|
||||||
|
return t.isa<IntegerType>() && !t.isUnsignedInteger();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct isUnsignedIntegerType {
|
||||||
|
bool operator()(Type t) { return t.isUnsignedInteger(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct isFloatType {
|
||||||
|
bool operator()(Type t) { return t.isa<FloatType>(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct isComplexType {
|
||||||
|
bool operator()(Type t) { return t.isa<ComplexType>(); }
|
||||||
|
};
|
||||||
|
|
||||||
// Inserts the computation that corresponds to the body of the loop for lowered
|
// Inserts the computation that corresponds to the body of the loop for lowered
|
||||||
// LHLO unary/binary op. Returns the value for the result.
|
// LHLO unary/binary op. Returns the value for the result.
|
||||||
template <typename LhloOpTy>
|
template <typename LhloOpTy>
|
||||||
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
|
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
|
return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<LhloOpTy>,
|
||||||
ScalarFOp<LhloOpTy>>{}(loc, result_types, args,
|
isUnsignedIntegerType, ScalarUOp<LhloOpTy>,
|
||||||
b);
|
isFloatType, ScalarFOp<LhloOpTy>>{}(
|
||||||
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
Type element_type = getElementTypeOrSelf(arg_types.front());
|
||||||
if (element_type.isa<FloatType>()) {
|
if (element_type.isa<FloatType>()) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::AbsFOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
if (element_type.isa<ComplexType>()) {
|
if (element_type.isa<ComplexType>()) {
|
||||||
return MapLhloOpToScalarOpImpl<ComplexType, ::mlir::complex::AbsOp>{}(
|
return MapLhloOpToScalarOpImpl<isComplexType, ::mlir::complex::AbsOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isSignlessInteger() || element_type.isSignedInteger()) {
|
||||||
// lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
|
// lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
|
||||||
Value lhs = args[0];
|
Value lhs = args[0];
|
||||||
auto integer_type = element_type.dyn_cast<IntegerType>();
|
auto integer_type = element_type.dyn_cast<IntegerType>();
|
||||||
|
@ -156,40 +195,44 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
|
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ScalarIOp<lmhlo::AddOp>,
|
||||||
FloatType, ScalarFOp<lmhlo::AddOp>,
|
isFloatType, ScalarFOp<lmhlo::AddOp>,
|
||||||
ComplexType, ScalarCOp<lmhlo::AddOp>>{}(
|
isComplexType, ScalarCOp<lmhlo::AddOp>>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::AndOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::Atan2Op>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::Atan2Op>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename PredicateType>
|
template <typename PredicateType>
|
||||||
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
|
inline Optional<PredicateType> getCmpPredicate(StringRef, bool) {
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
|
inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
|
||||||
StringRef comparison_direction) {
|
StringRef comparison_direction, bool is_signed) {
|
||||||
|
assert(is_signed && "cannot have an unsigned float!");
|
||||||
return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
|
return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
|
||||||
.Case("EQ", CmpFPredicate::OEQ)
|
.Case("EQ", CmpFPredicate::OEQ)
|
||||||
.Case("NE", CmpFPredicate::UNE)
|
.Case("NE", CmpFPredicate::UNE)
|
||||||
|
@ -202,14 +245,14 @@ inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
|
inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
|
||||||
StringRef comparison_direction) {
|
StringRef comparison_direction, bool is_signed) {
|
||||||
return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
|
return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
|
||||||
.Case("EQ", CmpIPredicate::eq)
|
.Case("EQ", CmpIPredicate::eq)
|
||||||
.Case("NE", CmpIPredicate::ne)
|
.Case("NE", CmpIPredicate::ne)
|
||||||
.Case("GE", CmpIPredicate::sge)
|
.Case("GE", is_signed ? CmpIPredicate::sge : CmpIPredicate::uge)
|
||||||
.Case("GT", CmpIPredicate::sgt)
|
.Case("GT", is_signed ? CmpIPredicate::sgt : CmpIPredicate::ugt)
|
||||||
.Case("LE", CmpIPredicate::sle)
|
.Case("LE", is_signed ? CmpIPredicate::sle : CmpIPredicate::ule)
|
||||||
.Case("LT", CmpIPredicate::slt)
|
.Case("LT", is_signed ? CmpIPredicate::slt : CmpIPredicate::ult)
|
||||||
.Default(llvm::None);
|
.Default(llvm::None);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -217,20 +260,21 @@ template <typename CompareOpTy>
|
||||||
inline Value MapCompareOpToStdScalarOp(Location loc,
|
inline Value MapCompareOpToStdScalarOp(Location loc,
|
||||||
StringRef comparison_direction,
|
StringRef comparison_direction,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
const auto& lhs = args[0];
|
const auto& lhs = args[0];
|
||||||
const auto& rhs = args[1];
|
const auto& rhs = args[1];
|
||||||
Type element_type = getElementTypeOrSelf(lhs.getType());
|
Type element_type = getElementTypeOrSelf(arg_types.front());
|
||||||
if (element_type.isSignlessInteger()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
Optional<CmpIPredicate> predicate =
|
Optional<CmpIPredicate> predicate = getCmpPredicate<CmpIPredicate>(
|
||||||
getCmpPredicate<CmpIPredicate>(comparison_direction);
|
comparison_direction, !element_type.isUnsignedInteger());
|
||||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||||
return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
|
return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||||
rhs);
|
rhs);
|
||||||
}
|
}
|
||||||
if (element_type.isa<FloatType>()) {
|
if (element_type.isa<FloatType>()) {
|
||||||
Optional<CmpFPredicate> predicate =
|
Optional<CmpFPredicate> predicate = getCmpPredicate<CmpFPredicate>(
|
||||||
getCmpPredicate<CmpFPredicate>(comparison_direction);
|
comparison_direction, /*is_signed=*/true);
|
||||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||||
return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
|
return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||||
rhs);
|
rhs);
|
||||||
|
@ -241,6 +285,7 @@ inline Value MapCompareOpToStdScalarOp(Location loc,
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return args.front();
|
return args.front();
|
||||||
|
@ -249,59 +294,66 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::ExpOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::ExpOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::ExpM1Op>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::CeilFOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||||
OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<complex::CreateOp>{}(loc, result_types, args,
|
return MapLhloOpToScalarOpImpl<complex::CreateOp>{}(loc, result_types,
|
||||||
b);
|
arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<complex::ReOp>{}(loc, result_types, args, b);
|
return MapLhloOpToScalarOpImpl<complex::ReOp>{}(loc, result_types, arg_types,
|
||||||
|
args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<complex::ImOp>{}(loc, result_types, args, b);
|
return MapLhloOpToScalarOpImpl<complex::ImOp>{}(loc, result_types, arg_types,
|
||||||
|
args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||||
OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
Type sourceType = getElementTypeOrSelf(args.front().getType());
|
Type sourceType = getElementTypeOrSelf(arg_types.front());
|
||||||
Type targetType = getElementTypeOrSelf(result_types.front());
|
Type targetType = getElementTypeOrSelf(result_types.front());
|
||||||
|
|
||||||
// A boolean value is considered to be unsigned when converting to
|
// A boolean value is considered to be unsigned when converting to
|
||||||
|
@ -342,7 +394,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
||||||
zero);
|
zero);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
|
if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) {
|
||||||
IntegerType src = sourceType.cast<IntegerType>();
|
IntegerType src = sourceType.cast<IntegerType>();
|
||||||
IntegerType res = targetType.cast<IntegerType>();
|
IntegerType res = targetType.cast<IntegerType>();
|
||||||
if (src.getWidth() > res.getWidth()) {
|
if (src.getWidth() > res.getWidth()) {
|
||||||
|
@ -352,6 +404,10 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
||||||
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
|
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
|
||||||
mlir::None);
|
mlir::None);
|
||||||
} else if (src.getWidth() < res.getWidth()) {
|
} else if (src.getWidth() < res.getWidth()) {
|
||||||
|
if (src.isUnsignedInteger()) {
|
||||||
|
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
|
||||||
|
mlir::None);
|
||||||
|
}
|
||||||
return b->create<mlir::SignExtendIOp>(loc, result_types, args,
|
return b->create<mlir::SignExtendIOp>(loc, result_types, args,
|
||||||
mlir::None);
|
mlir::None);
|
||||||
}
|
}
|
||||||
|
@ -367,6 +423,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
// Dot Op converter from lhlo to affine only accepts float and integer types.
|
// Dot Op converter from lhlo to affine only accepts float and integer types.
|
||||||
|
@ -375,16 +432,16 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
|
||||||
const auto& result = args[2];
|
const auto& result = args[2];
|
||||||
Type element_type = lhs.getType();
|
Type element_type = lhs.getType();
|
||||||
if (element_type.isa<FloatType>()) {
|
if (element_type.isa<FloatType>()) {
|
||||||
Value float_mul = MapLhloOpToScalarOpImpl<FloatType, ::mlir::MulFOp>{}(
|
Value float_mul = MapLhloOpToScalarOpImpl<isFloatType, ::mlir::MulFOp>{}(
|
||||||
loc, result_types, {lhs, rhs}, b);
|
loc, result_types, arg_types, {lhs, rhs}, b);
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::AddFOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::AddFOp>{}(
|
||||||
loc, result_types, {float_mul, result}, b);
|
loc, result_types, arg_types, {float_mul, result}, b);
|
||||||
}
|
}
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
Value int_mul = MapLhloOpToScalarOpImpl<IntegerType, ::mlir::MulIOp>{}(
|
Value int_mul = MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::MulIOp>{}(
|
||||||
loc, result_types, {lhs, rhs}, b);
|
loc, result_types, arg_types, {lhs, rhs}, b);
|
||||||
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::AddIOp>{}(
|
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::AddIOp>{}(
|
||||||
loc, result_types, {int_mul, result}, b);
|
loc, result_types, arg_types, {int_mul, result}, b);
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -392,34 +449,37 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::CosOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::CosOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::SinOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::SinOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::FloorFOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::FloorFOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||||
OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
if (args[0].getType().isa<FloatType>()) {
|
if (args[0].getType().isa<FloatType>()) {
|
||||||
auto pos_inf = APFloat::getInf(
|
auto pos_inf = APFloat::getInf(
|
||||||
args[0].getType().cast<FloatType>().getFloatSemantics());
|
args[0].getType().cast<FloatType>().getFloatSemantics());
|
||||||
|
@ -437,8 +497,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
struct CompareSelectOpToStdScalarOp {
|
struct CompareSelectOpToStdScalarOp {
|
||||||
static Value map(Location loc, StringRef comparison_direction,
|
static Value map(Location loc, StringRef comparison_direction,
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||||
OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -450,28 +510,30 @@ template <typename SupportedType, typename StdCompareOp, typename Predicate,
|
||||||
struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
|
struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
|
||||||
Args...> {
|
Args...> {
|
||||||
static Value map(Location loc, StringRef comparison_direction,
|
static Value map(Location loc, StringRef comparison_direction,
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||||
OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
Type element_type = getElementTypeOrSelf(arg_types.front());
|
||||||
if (element_type.isa<SupportedType>()) {
|
if (element_type.isa<SupportedType>()) {
|
||||||
auto predicate = getCmpPredicate<Predicate>(comparison_direction);
|
auto predicate = getCmpPredicate<Predicate>(
|
||||||
|
comparison_direction, !element_type.isUnsignedInteger());
|
||||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||||
auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(),
|
auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(),
|
||||||
args[0], args[1]);
|
args[0], args[1]);
|
||||||
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
|
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
|
||||||
}
|
}
|
||||||
return CompareSelectOpToStdScalarOp<Args...>::map(loc, comparison_direction,
|
return CompareSelectOpToStdScalarOp<Args...>::map(
|
||||||
result_types, args, b);
|
loc, comparison_direction, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::LogOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::LogOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc,
|
inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc,
|
||||||
|
@ -493,8 +555,8 @@ inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc,
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||||
OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
auto ty = result_types.front().cast<FloatType>();
|
auto ty = result_types.front().cast<FloatType>();
|
||||||
Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0));
|
Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0));
|
||||||
Value x = args.front();
|
Value x = args.front();
|
||||||
|
@ -507,43 +569,47 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::Log1pOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::Log1pOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return LhloAlwaysPropagateNaN(
|
return LhloAlwaysPropagateNaN(
|
||||||
CompareSelectOpToStdScalarOp<
|
CompareSelectOpToStdScalarOp<
|
||||||
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||||
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
|
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
|
||||||
result_types, args,
|
result_types,
|
||||||
b),
|
arg_types, args, b),
|
||||||
args, loc, b);
|
args, loc, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return LhloAlwaysPropagateNaN(
|
return LhloAlwaysPropagateNaN(
|
||||||
CompareSelectOpToStdScalarOp<
|
CompareSelectOpToStdScalarOp<
|
||||||
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||||
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
|
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
|
||||||
result_types, args,
|
result_types,
|
||||||
b),
|
arg_types, args, b),
|
||||||
args, loc, b);
|
args, loc, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
assert(args.size() == 3 && "expected 3 arguments");
|
assert(args.size() == 3 && "expected 3 arguments");
|
||||||
|
@ -552,21 +618,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
|
||||||
Value ub = args[2];
|
Value ub = args[2];
|
||||||
|
|
||||||
// clamp(lb, x, ub) = max(min(x, ub), lb)
|
// clamp(lb, x, ub) = max(min(x, ub), lb)
|
||||||
Value min_x_ub =
|
Value min_x_ub = MapLhloOpToStdScalarOp<lmhlo::MinOp>(loc, result_types,
|
||||||
MapLhloOpToStdScalarOp<lmhlo::MinOp>(loc, result_types, {x, ub}, b);
|
arg_types, {x, ub}, b);
|
||||||
return MapLhloOpToStdScalarOp<lmhlo::MaxOp>(loc, result_types, {min_x_ub, lb},
|
return MapLhloOpToStdScalarOp<lmhlo::MaxOp>(loc, result_types, arg_types,
|
||||||
b);
|
{min_x_ub, lb}, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||||
if (element_type.isa<FloatType>()) {
|
if (element_type.isa<FloatType>()) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
// lmhlo.neg(x, result) -> result = sub(0, x)
|
// lmhlo.neg(x, result) -> result = sub(0, x)
|
||||||
|
@ -586,6 +653,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||||
|
@ -604,24 +672,27 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::OrOp>{}(loc, result_types,
|
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::OrOp>{}(
|
||||||
args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::RsqrtOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::RsqrtOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
lmhlo::PowOp::Adaptor adaptor(args);
|
lmhlo::PowOp::Adaptor adaptor(args);
|
||||||
|
@ -630,7 +701,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
|
||||||
auto result_type = result_types.front();
|
auto result_type = result_types.front();
|
||||||
if (result_type.isa<::mlir::FloatType>())
|
if (result_type.isa<::mlir::FloatType>())
|
||||||
return MapLhloOpToScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types,
|
return MapLhloOpToScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types,
|
||||||
args, b);
|
arg_types, args, b);
|
||||||
|
|
||||||
assert(result_type.isa<::mlir::IntegerType>() &&
|
assert(result_type.isa<::mlir::IntegerType>() &&
|
||||||
"only float and integer `pow` is supported right now");
|
"only float and integer `pow` is supported right now");
|
||||||
|
@ -699,39 +770,41 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||||
OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
|
return MapLhloOpToScalarOpImpl<::mlir::SelectOp>{}(loc, result_types,
|
||||||
b);
|
arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||||
OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
|
return MapLhloOpToScalarOpImpl<isAnyIntegerType, mlir::ShiftLeftOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||||
OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
|
return MapLhloOpToScalarOpImpl<isAnyIntegerType, mlir::SignedShiftRightOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||||
OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
|
return MapLhloOpToScalarOpImpl<isAnyIntegerType,
|
||||||
loc, result_types, args, b);
|
mlir::UnsignedShiftRightOp>{}(
|
||||||
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||||
|
@ -780,39 +853,43 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::SqrtOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::SqrtOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
|
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ScalarIOp<lmhlo::SubOp>,
|
||||||
FloatType, ScalarFOp<lmhlo::SubOp>,
|
isFloatType, ScalarFOp<lmhlo::SubOp>,
|
||||||
ComplexType, ScalarCOp<lmhlo::SubOp>>{}(
|
isComplexType, ScalarCOp<lmhlo::SubOp>>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::TanhOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::TanhOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
|
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::XOrOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace impl
|
} // namespace impl
|
||||||
|
@ -826,8 +903,9 @@ struct HloOpToStdScalarOp {
|
||||||
std::false_type>::value>>
|
std::false_type>::value>>
|
||||||
static Value map(HloOpTy op, ArrayRef<Type> result_types,
|
static Value map(HloOpTy op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
|
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
|
||||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
|
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(
|
||||||
args, b);
|
op.getLoc(), result_types, llvm::to_vector<4>(op->getOperandTypes()),
|
||||||
|
args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implementation for HLO ops except mhlo::CompareOp.
|
// Implementation for HLO ops except mhlo::CompareOp.
|
||||||
|
@ -837,8 +915,9 @@ struct HloOpToStdScalarOp {
|
||||||
!std::is_same<LhloOpTy, std::false_type>::value>>
|
!std::is_same<LhloOpTy, std::false_type>::value>>
|
||||||
static Value map(HloOpTy op, ArrayRef<Type> result_types,
|
static Value map(HloOpTy op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
|
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
|
||||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
|
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(
|
||||||
args, b);
|
op.getLoc(), result_types, llvm::to_vector<4>(op->getOperandTypes()),
|
||||||
|
args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implementation for lmhlo::CompareOp.
|
// Implementation for lmhlo::CompareOp.
|
||||||
|
@ -848,7 +927,8 @@ struct HloOpToStdScalarOp {
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
auto comparison_direction = op.comparison_direction();
|
auto comparison_direction = op.comparison_direction();
|
||||||
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
||||||
op.getLoc(), comparison_direction, result_types, args, b);
|
op.getLoc(), comparison_direction, result_types,
|
||||||
|
llvm::to_vector<4>(op->getOperandTypes()), args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implementation for mhlo::CompareOp.
|
// Implementation for mhlo::CompareOp.
|
||||||
|
@ -859,7 +939,8 @@ struct HloOpToStdScalarOp {
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
auto comparison_direction = op.comparison_direction();
|
auto comparison_direction = op.comparison_direction();
|
||||||
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
||||||
op.getLoc(), comparison_direction, result_types, args, b);
|
op.getLoc(), comparison_direction, result_types,
|
||||||
|
llvm::to_vector<4>(op->getOperandTypes()), args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implementation for LHLO ops except lmhlo::CompareOp.
|
// Implementation for LHLO ops except lmhlo::CompareOp.
|
||||||
|
@ -869,18 +950,20 @@ struct HloOpToStdScalarOp {
|
||||||
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
|
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
|
||||||
std::false_type>::value>>
|
std::false_type>::value>>
|
||||||
static Value map(Location loc, ArrayRef<Type> result_types,
|
static Value map(Location loc, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
|
ArrayRef<Type> arg_types, ArrayRef<Value> args, OpBuilder* b,
|
||||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, args, b);
|
unsigned i = 0) {
|
||||||
|
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, arg_types,
|
||||||
|
args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implementation for lmhlo::CompareOp.
|
// Implementation for lmhlo::CompareOp.
|
||||||
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
|
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
|
||||||
LhloOpTy, lmhlo::CompareOp>::value>>
|
LhloOpTy, lmhlo::CompareOp>::value>>
|
||||||
static Value map(Location loc, StringRef comparison_direction,
|
static Value map(Location loc, StringRef comparison_direction,
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||||
OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
||||||
loc, comparison_direction, result_types, args, b);
|
loc, comparison_direction, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -57,8 +57,12 @@ void populateHLOToLHLOConversionPattern(MLIRContext *context,
|
||||||
|
|
||||||
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
|
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
|
||||||
void populateHLOToLinalgConversionPattern(MLIRContext *context,
|
void populateHLOToLinalgConversionPattern(MLIRContext *context,
|
||||||
|
TypeConverter &typeConverter,
|
||||||
OwningRewritePatternList *patterns);
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
|
// Converter to signless intergers to be used with linalg conversion patterns.
|
||||||
|
std::unique_ptr<TypeConverter> createHloToLinalgSignedIntegerConverter();
|
||||||
|
|
||||||
// Sets up legality definitions for materializing broadcasts.
|
// Sets up legality definitions for materializing broadcasts.
|
||||||
void SetupMaterializeBroadcastsLegality(MLIRContext *context,
|
void SetupMaterializeBroadcastsLegality(MLIRContext *context,
|
||||||
ConversionTarget *conversionTarget);
|
ConversionTarget *conversionTarget);
|
||||||
|
|
|
@ -241,16 +241,15 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
||||||
!(t.getElementType().isSignlessIntOrFloat() ||
|
!(t.getElementType().isSignlessIntOrFloat() ||
|
||||||
t.getElementType().isa<ComplexType>());
|
t.getElementType().isa<ComplexType>());
|
||||||
};
|
};
|
||||||
if (llvm::any_of(args,
|
if (llvm::any_of(op.getOperation()->getResultTypes(), [&](Type t) {
|
||||||
[&](Value v) {
|
return fail(this->typeConverter->convertType(t)
|
||||||
return fail(v.getType().dyn_cast<ShapedType>());
|
.template dyn_cast<ShapedType>());
|
||||||
}) ||
|
})) {
|
||||||
llvm::any_of(op.getOperation()->getResultTypes(),
|
|
||||||
[&](Type t) { return fail(t.dyn_cast<ShapedType>()); }))
|
|
||||||
return emitError(loc,
|
return emitError(loc,
|
||||||
"lhlo to linalg conversion expects ranked args of "
|
"hlo to linalg conversion expects ranked args of "
|
||||||
"signless int, float or complex element type with ")
|
"signless int, float or complex element type with ")
|
||||||
<< nloops << " parallel iterators: " << *(op.getOperation());
|
<< nloops << " parallel iterators: " << *(op.getOperation());
|
||||||
|
}
|
||||||
|
|
||||||
// Construct the indexing maps needed for linalg.generic ops.
|
// Construct the indexing maps needed for linalg.generic ops.
|
||||||
SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
|
SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
|
||||||
|
@ -270,12 +269,12 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
||||||
if (isLHLO) {
|
if (isLHLO) {
|
||||||
output_buffers.append(args.begin() + num_inputs, args.end());
|
output_buffers.append(args.begin() + num_inputs, args.end());
|
||||||
} else {
|
} else {
|
||||||
Value result = op.getOperation()->getResult(0);
|
Type result_type = this->typeConverter->convertType(
|
||||||
ShapedType result_type = result.getType().template cast<ShapedType>();
|
op.getOperation()->getResult(0).getType());
|
||||||
auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
|
auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
|
||||||
output_buffers.push_back(
|
output_buffers.push_back(GetInitTensor(
|
||||||
GetInitTensor(rewriter, loc, result_type, dyn_sizes));
|
rewriter, loc, result_type.cast<ShapedType>(), dyn_sizes));
|
||||||
op_result_types.push_back(result.getType());
|
op_result_types.push_back(result_type);
|
||||||
}
|
}
|
||||||
body_result_types = llvm::to_vector<4>(llvm::map_range(
|
body_result_types = llvm::to_vector<4>(llvm::map_range(
|
||||||
output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));
|
output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));
|
||||||
|
@ -1279,8 +1278,10 @@ class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
|
||||||
// map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it
|
// map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it
|
||||||
// to an lmhlo op and call the lmhlo implementation.
|
// to an lmhlo op and call the lmhlo implementation.
|
||||||
start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
|
start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
|
||||||
loc, start_index.getType(), ArrayRef<Value>{zero, start_index, ub},
|
loc, start_index.getType(),
|
||||||
&rewriter);
|
ArrayRef<Type>{start_index.getType(), start_index.getType(),
|
||||||
|
start_index.getType()},
|
||||||
|
ArrayRef<Value>{zero, start_index, ub}, &rewriter);
|
||||||
start_indices.push_back(
|
start_indices.push_back(
|
||||||
rewriter.create<IndexCastOp>(loc, index_type, start_index)
|
rewriter.create<IndexCastOp>(loc, index_type, start_index)
|
||||||
.getResult());
|
.getResult());
|
||||||
|
@ -2073,6 +2074,7 @@ struct TorchIndexSelectOpOnTensorsConversion
|
||||||
};
|
};
|
||||||
|
|
||||||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
|
TypeConverter& typeConverter,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
||||||
|
@ -2128,10 +2130,65 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
|
ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
|
||||||
SliceConverter<lmhlo::SliceOp>,
|
SliceConverter<lmhlo::SliceOp>,
|
||||||
TransposeConverter<lmhlo::TransposeOp>
|
TransposeConverter<lmhlo::TransposeOp>
|
||||||
>(context);
|
>(typeConverter, context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converter that turns signed/unsigned integers types into signless types.
|
||||||
|
class RemoveSignTypeConverter : public TypeConverter {
|
||||||
|
public:
|
||||||
|
RemoveSignTypeConverter() {
|
||||||
|
addConversion([](Type type) { return type; });
|
||||||
|
|
||||||
|
addConversion(convertInteger);
|
||||||
|
addConversion(convertShapedType);
|
||||||
|
|
||||||
|
addArgumentMaterialization(materializeCastFromIllegal);
|
||||||
|
addSourceMaterialization(materializeCastToIllegal);
|
||||||
|
addTargetMaterialization(materializeCastFromIllegal);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static Type convertInteger(IntegerType int_type) {
|
||||||
|
return IntegerType::get(int_type.getContext(),
|
||||||
|
int_type.getIntOrFloatBitWidth());
|
||||||
|
}
|
||||||
|
|
||||||
|
static Type convertShapedType(ShapedType shaped_type) {
|
||||||
|
if (auto int_type = shaped_type.getElementType().dyn_cast<IntegerType>())
|
||||||
|
return shaped_type.clone(convertInteger(int_type));
|
||||||
|
return shaped_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
static llvm::Optional<Value> materializeCastFromIllegal(OpBuilder& builder,
|
||||||
|
Type type,
|
||||||
|
ValueRange inputs,
|
||||||
|
Location loc) {
|
||||||
|
Type from_type = getElementTypeOrSelf(inputs[0].getType());
|
||||||
|
Type to_type = getElementTypeOrSelf(type);
|
||||||
|
if ((!from_type.isSignedInteger() && !from_type.isUnsignedInteger()) ||
|
||||||
|
!to_type.isSignlessInteger())
|
||||||
|
return llvm::None;
|
||||||
|
// Use unrealized conversion casts to do signful->signless conversions.
|
||||||
|
return builder.create<UnrealizedConversionCastOp>(loc, type, inputs[0])
|
||||||
|
->getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static llvm::Optional<Value> materializeCastToIllegal(OpBuilder& builder,
|
||||||
|
Type type,
|
||||||
|
ValueRange inputs,
|
||||||
|
Location loc) {
|
||||||
|
Type from_type = getElementTypeOrSelf(inputs[0].getType());
|
||||||
|
Type to_type = getElementTypeOrSelf(type);
|
||||||
|
if (!from_type.isSignlessInteger() ||
|
||||||
|
(!to_type.isSignedInteger() && !to_type.isUnsignedInteger()))
|
||||||
|
return llvm::None;
|
||||||
|
// Use unrealized conversion casts to do signless->signful conversions.
|
||||||
|
return builder.create<UnrealizedConversionCastOp>(loc, type, inputs[0])
|
||||||
|
->getResult(0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Converts LHLO ops to Linalg generic.
|
// Converts LHLO ops to Linalg generic.
|
||||||
// Sample result for lmhlo::AddOp.
|
// Sample result for lmhlo::AddOp.
|
||||||
//
|
//
|
||||||
|
@ -2163,9 +2220,12 @@ struct LhloLegalizeToLinalgPass
|
||||||
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
||||||
math::MathDialect, memref::MemRefDialect,
|
math::MathDialect, memref::MemRefDialect,
|
||||||
StandardOpsDialect, AffineDialect>();
|
StandardOpsDialect, AffineDialect>();
|
||||||
|
target.addLegalOp<UnrealizedConversionCastOp>();
|
||||||
|
|
||||||
|
RemoveSignTypeConverter type_converter;
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
populateLHLOToLinalgConversionPattern(func.getContext(), type_converter,
|
||||||
|
&patterns);
|
||||||
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
@ -2181,17 +2241,20 @@ struct HloLegalizeToLinalgPass
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns(&getContext());
|
MLIRContext& ctx = getContext();
|
||||||
ConversionTarget target(getContext());
|
OwningRewritePatternList patterns(&ctx);
|
||||||
|
ConversionTarget target(ctx);
|
||||||
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
||||||
math::MathDialect, StandardOpsDialect,
|
math::MathDialect, StandardOpsDialect,
|
||||||
tensor::TensorDialect, scf::SCFDialect>();
|
tensor::TensorDialect, scf::SCFDialect>();
|
||||||
|
|
||||||
// TODO: DimOp shouldn't be in MemRefDialect
|
// TODO: DimOp shouldn't be in MemRefDialect
|
||||||
target.addLegalOp<memref::DimOp>();
|
target.addLegalOp<memref::DimOp>();
|
||||||
|
target.addLegalOp<UnrealizedConversionCastOp>();
|
||||||
|
|
||||||
|
RemoveSignTypeConverter type_converter;
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
mhlo::populateHLOToLinalgConversionPattern(&ctx, type_converter, &patterns);
|
||||||
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
@ -2209,6 +2272,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
|
||||||
namespace mhlo {
|
namespace mhlo {
|
||||||
|
|
||||||
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
|
TypeConverter& type_converter,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
|
@ -2272,7 +2336,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
ReduceOnTensorsConversion,
|
ReduceOnTensorsConversion,
|
||||||
ReduceWindowOpOnTensorsConversion,
|
ReduceWindowOpOnTensorsConversion,
|
||||||
TorchIndexSelectOpOnTensorsConversion,
|
TorchIndexSelectOpOnTensorsConversion,
|
||||||
PadOpOnTensorsConversion>(context);
|
PadOpOnTensorsConversion>(type_converter, context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
|
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
|
||||||
ReduceRegionXLAOpConversion<mhlo::MinOp>,
|
ReduceRegionXLAOpConversion<mhlo::MinOp>,
|
||||||
|
@ -2287,5 +2351,10 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||||
return std::make_unique<HloLegalizeToLinalgPass>();
|
return std::make_unique<HloLegalizeToLinalgPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<TypeConverter> createHloToLinalgSignedIntegerConverter() {
|
||||||
|
return std::make_unique<RemoveSignTypeConverter>();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mhlo
|
} // namespace mhlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -603,6 +603,21 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @maxu
|
||||||
|
func @maxu(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
|
||||||
|
%0 = "mhlo.maximum"(%lhs, %rhs)
|
||||||
|
: (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
|
||||||
|
return %0 : tensor<2x2xui32>
|
||||||
|
}
|
||||||
|
// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi32>
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32):
|
||||||
|
// CHECK-NEXT: %[[CMP:.*]] = cmpi ugt, %[[LHS_IN]], %[[RHS_IN]] : i32
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()>
|
// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()>
|
||||||
// CHECK-LABEL: func @add_scalar
|
// CHECK-LABEL: func @add_scalar
|
||||||
func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
|
func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
|
||||||
|
@ -2196,3 +2211,43 @@ func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>)
|
||||||
} : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
} : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||||
return %concat : tensor<?x?xi32>
|
return %concat : tensor<?x?xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: unsigned_divide
|
||||||
|
func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: divi_unsigned
|
||||||
|
%0 = "mhlo.divide"(%lhs, %rhs) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
|
||||||
|
return %0 : tensor<2x2xui32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: unsigned_remainder
|
||||||
|
func @unsigned_remainder(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: remi_unsigned
|
||||||
|
%0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
|
||||||
|
return %0 : tensor<2x2xui32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: unsigned_convert
|
||||||
|
func @unsigned_convert(%in: tensor<2x2xui32>) -> tensor<2x2xui64> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: zexti
|
||||||
|
%0 = "mhlo.convert"(%in) : (tensor<2x2xui32>) -> tensor<2x2xui64>
|
||||||
|
return %0 : tensor<2x2xui64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: unsigned_compare
|
||||||
|
func @unsigned_compare(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xi1> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: cmpi ugt
|
||||||
|
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xi1>
|
||||||
|
return %0 : tensor<2x2xi1>
|
||||||
|
}
|
||||||
|
|
|
@ -93,6 +93,21 @@ func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @maxu
|
||||||
|
func @maxu(%lhs: memref<2x2xui32>, %rhs: memref<2x2xui32>,
|
||||||
|
%result: memref<2x2xui32>) {
|
||||||
|
"lmhlo.maximum"(%lhs, %rhs, %result)
|
||||||
|
: (memref<2x2xui32>, memref<2x2xui32>, memref<2x2xui32>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i32):
|
||||||
|
// CHECK-NEXT: %[[CMP:.*]] = cmpi ugt, %[[LHS_IN]], %[[RHS_IN]] : i32
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @and
|
// CHECK-LABEL: func @and
|
||||||
func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||||
%result: memref<2x2xi32>) {
|
%result: memref<2x2xi32>) {
|
||||||
|
|
Loading…
Reference in New Issue