[MHLO:Linalg] Add support for lowering unsigned ops
This strips away the signedness with a type converter, using unrealized conversion casts. The rest is mostly mechanically pushing the original op down the pipeline so lowerings can see the original types. Signed types stay signless for now. This can be changed in the HLO bridge later. I did a pass over all ops and added unsigned lowerings where they were missing. There may be more. Currently the lowering will die at a later stage because it doesn't understand the unrealized casts. PiperOrigin-RevId: 371077494
This commit is contained in:
parent
09f6141562
commit
f4414fcd66
|
@ -44,41 +44,50 @@ template <>
|
|||
struct LhloToScalarOp<lmhlo::AddOp> {
|
||||
using FOp = ::mlir::AddFOp;
|
||||
using IOp = ::mlir::AddIOp;
|
||||
using UOp = ::mlir::AddIOp;
|
||||
using COp = ::mlir::complex::AddOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<lmhlo::CompareOp> {
|
||||
using FOp = ::mlir::CmpFOp;
|
||||
using IOp = ::mlir::CmpIOp;
|
||||
using UOp = ::mlir::CmpIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<lmhlo::DivOp> {
|
||||
using FOp = ::mlir::DivFOp;
|
||||
using IOp = ::mlir::SignedDivIOp;
|
||||
using UOp = ::mlir::UnsignedDivIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<lmhlo::MulOp> {
|
||||
using FOp = ::mlir::MulFOp;
|
||||
using IOp = ::mlir::MulIOp;
|
||||
using UOp = ::mlir::MulIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<lmhlo::RemOp> {
|
||||
using FOp = ::mlir::RemFOp;
|
||||
using IOp = ::mlir::SignedRemIOp;
|
||||
using UOp = ::mlir::UnsignedRemIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<lmhlo::SubOp> {
|
||||
using FOp = ::mlir::SubFOp;
|
||||
using IOp = ::mlir::SubIOp;
|
||||
using UOp = ::mlir::SubIOp;
|
||||
using COp = ::mlir::complex::SubOp;
|
||||
};
|
||||
|
||||
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
||||
template <typename LhloOp>
|
||||
using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp;
|
||||
// Alias for the map from LHLO binary op type to STD integer op type.
|
||||
// Alias for the map from LHLO binary op type to STD signed integer op type.
|
||||
template <typename LhloOp>
|
||||
using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp;
|
||||
// Alias for the map from LHLO binary op type to STD unsigned integer op type.
|
||||
template <typename LhloOp>
|
||||
using ScalarUOp = typename LhloToScalarOp<LhloOp>::UOp;
|
||||
// Alias for the map from LHLO binary op type to STD complex op type.
|
||||
template <typename LhloOp>
|
||||
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
|
||||
|
@ -86,7 +95,8 @@ using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
|
|||
template <typename... Args>
|
||||
struct MapLhloOpToScalarOpImpl {
|
||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
ArrayRef<Type> arg_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
@ -94,7 +104,8 @@ struct MapLhloOpToScalarOpImpl {
|
|||
template <typename StdScalarOp>
|
||||
struct MapLhloOpToScalarOpImpl<StdScalarOp> {
|
||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
ArrayRef<Type> arg_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
|
||||
}
|
||||
};
|
||||
|
@ -102,41 +113,69 @@ struct MapLhloOpToScalarOpImpl<StdScalarOp> {
|
|||
template <typename SupportedType, typename StdScalarOp, typename... Args>
|
||||
struct MapLhloOpToScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
if (element_type.isa<SupportedType>()) {
|
||||
ArrayRef<Type> arg_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = getElementTypeOrSelf(arg_types.front());
|
||||
if (SupportedType{}(element_type)) {
|
||||
return b->template create<StdScalarOp>(loc, result_types, args,
|
||||
mlir::None);
|
||||
}
|
||||
return MapLhloOpToScalarOpImpl<Args...>{}(loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<Args...>{}(loc, result_types, arg_types,
|
||||
args, b);
|
||||
}
|
||||
};
|
||||
|
||||
struct isAnyIntegerType {
|
||||
bool operator()(Type t) { return t.isa<IntegerType>(); }
|
||||
};
|
||||
|
||||
struct isSignedIntegerType {
|
||||
bool operator()(Type t) {
|
||||
// Pretend that signless is signed. This will change eventually.
|
||||
return t.isa<IntegerType>() && !t.isUnsignedInteger();
|
||||
}
|
||||
};
|
||||
|
||||
struct isUnsignedIntegerType {
|
||||
bool operator()(Type t) { return t.isUnsignedInteger(); }
|
||||
};
|
||||
|
||||
struct isFloatType {
|
||||
bool operator()(Type t) { return t.isa<FloatType>(); }
|
||||
};
|
||||
|
||||
struct isComplexType {
|
||||
bool operator()(Type t) { return t.isa<ComplexType>(); }
|
||||
};
|
||||
|
||||
// Inserts the computation that corresponds to the body of the loop for lowered
|
||||
// LHLO unary/binary op. Returns the value for the result.
|
||||
template <typename LhloOpTy>
|
||||
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
|
||||
ScalarFOp<LhloOpTy>>{}(loc, result_types, args,
|
||||
b);
|
||||
return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<LhloOpTy>,
|
||||
isUnsignedIntegerType, ScalarUOp<LhloOpTy>,
|
||||
isFloatType, ScalarFOp<LhloOpTy>>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
Type element_type = getElementTypeOrSelf(arg_types.front());
|
||||
if (element_type.isa<FloatType>()) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::AbsFOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
if (element_type.isa<ComplexType>()) {
|
||||
return MapLhloOpToScalarOpImpl<ComplexType, ::mlir::complex::AbsOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isComplexType, ::mlir::complex::AbsOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
if (element_type.isSignlessInteger() || element_type.isSignedInteger()) {
|
||||
// lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
|
||||
Value lhs = args[0];
|
||||
auto integer_type = element_type.dyn_cast<IntegerType>();
|
||||
|
@ -156,40 +195,44 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
|||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
|
||||
FloatType, ScalarFOp<lmhlo::AddOp>,
|
||||
ComplexType, ScalarCOp<lmhlo::AddOp>>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ScalarIOp<lmhlo::AddOp>,
|
||||
isFloatType, ScalarFOp<lmhlo::AddOp>,
|
||||
isComplexType, ScalarCOp<lmhlo::AddOp>>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::AndOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::Atan2Op>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::Atan2Op>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <typename PredicateType>
|
||||
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
|
||||
inline Optional<PredicateType> getCmpPredicate(StringRef, bool) {
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
|
||||
StringRef comparison_direction) {
|
||||
StringRef comparison_direction, bool is_signed) {
|
||||
assert(is_signed && "cannot have an unsigned float!");
|
||||
return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
|
||||
.Case("EQ", CmpFPredicate::OEQ)
|
||||
.Case("NE", CmpFPredicate::UNE)
|
||||
|
@ -202,14 +245,14 @@ inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
|
|||
|
||||
template <>
|
||||
inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
|
||||
StringRef comparison_direction) {
|
||||
StringRef comparison_direction, bool is_signed) {
|
||||
return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
|
||||
.Case("EQ", CmpIPredicate::eq)
|
||||
.Case("NE", CmpIPredicate::ne)
|
||||
.Case("GE", CmpIPredicate::sge)
|
||||
.Case("GT", CmpIPredicate::sgt)
|
||||
.Case("LE", CmpIPredicate::sle)
|
||||
.Case("LT", CmpIPredicate::slt)
|
||||
.Case("GE", is_signed ? CmpIPredicate::sge : CmpIPredicate::uge)
|
||||
.Case("GT", is_signed ? CmpIPredicate::sgt : CmpIPredicate::ugt)
|
||||
.Case("LE", is_signed ? CmpIPredicate::sle : CmpIPredicate::ule)
|
||||
.Case("LT", is_signed ? CmpIPredicate::slt : CmpIPredicate::ult)
|
||||
.Default(llvm::None);
|
||||
}
|
||||
|
||||
|
@ -217,20 +260,21 @@ template <typename CompareOpTy>
|
|||
inline Value MapCompareOpToStdScalarOp(Location loc,
|
||||
StringRef comparison_direction,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
const auto& lhs = args[0];
|
||||
const auto& rhs = args[1];
|
||||
Type element_type = getElementTypeOrSelf(lhs.getType());
|
||||
if (element_type.isSignlessInteger()) {
|
||||
Optional<CmpIPredicate> predicate =
|
||||
getCmpPredicate<CmpIPredicate>(comparison_direction);
|
||||
Type element_type = getElementTypeOrSelf(arg_types.front());
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
Optional<CmpIPredicate> predicate = getCmpPredicate<CmpIPredicate>(
|
||||
comparison_direction, !element_type.isUnsignedInteger());
|
||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||
return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||
rhs);
|
||||
}
|
||||
if (element_type.isa<FloatType>()) {
|
||||
Optional<CmpFPredicate> predicate =
|
||||
getCmpPredicate<CmpFPredicate>(comparison_direction);
|
||||
Optional<CmpFPredicate> predicate = getCmpPredicate<CmpFPredicate>(
|
||||
comparison_direction, /*is_signed=*/true);
|
||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||
return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||
rhs);
|
||||
|
@ -241,6 +285,7 @@ inline Value MapCompareOpToStdScalarOp(Location loc,
|
|||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return args.front();
|
||||
|
@ -249,59 +294,66 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
|
|||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::ExpOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::ExpOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::ExpM1Op>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::CeilFOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<complex::CreateOp>{}(loc, result_types, args,
|
||||
b);
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<complex::CreateOp>{}(loc, result_types,
|
||||
arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<complex::ReOp>{}(loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<complex::ReOp>{}(loc, result_types, arg_types,
|
||||
args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<complex::ImOp>{}(loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<complex::ImOp>{}(loc, result_types, arg_types,
|
||||
args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type sourceType = getElementTypeOrSelf(args.front().getType());
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
Type sourceType = getElementTypeOrSelf(arg_types.front());
|
||||
Type targetType = getElementTypeOrSelf(result_types.front());
|
||||
|
||||
// A boolean value is considered to be unsigned when converting to
|
||||
|
@ -342,7 +394,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
|||
zero);
|
||||
}
|
||||
}
|
||||
if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
|
||||
if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) {
|
||||
IntegerType src = sourceType.cast<IntegerType>();
|
||||
IntegerType res = targetType.cast<IntegerType>();
|
||||
if (src.getWidth() > res.getWidth()) {
|
||||
|
@ -352,6 +404,10 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
|||
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
|
||||
mlir::None);
|
||||
} else if (src.getWidth() < res.getWidth()) {
|
||||
if (src.isUnsignedInteger()) {
|
||||
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
|
||||
mlir::None);
|
||||
}
|
||||
return b->create<mlir::SignExtendIOp>(loc, result_types, args,
|
||||
mlir::None);
|
||||
}
|
||||
|
@ -367,6 +423,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
|||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
// Dot Op converter from lhlo to affine only accepts float and integer types.
|
||||
|
@ -375,16 +432,16 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
|
|||
const auto& result = args[2];
|
||||
Type element_type = lhs.getType();
|
||||
if (element_type.isa<FloatType>()) {
|
||||
Value float_mul = MapLhloOpToScalarOpImpl<FloatType, ::mlir::MulFOp>{}(
|
||||
loc, result_types, {lhs, rhs}, b);
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::AddFOp>{}(
|
||||
loc, result_types, {float_mul, result}, b);
|
||||
Value float_mul = MapLhloOpToScalarOpImpl<isFloatType, ::mlir::MulFOp>{}(
|
||||
loc, result_types, arg_types, {lhs, rhs}, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::AddFOp>{}(
|
||||
loc, result_types, arg_types, {float_mul, result}, b);
|
||||
}
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
Value int_mul = MapLhloOpToScalarOpImpl<IntegerType, ::mlir::MulIOp>{}(
|
||||
loc, result_types, {lhs, rhs}, b);
|
||||
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::AddIOp>{}(
|
||||
loc, result_types, {int_mul, result}, b);
|
||||
Value int_mul = MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::MulIOp>{}(
|
||||
loc, result_types, arg_types, {lhs, rhs}, b);
|
||||
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::AddIOp>{}(
|
||||
loc, result_types, arg_types, {int_mul, result}, b);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -392,34 +449,37 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
|
|||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::CosOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::CosOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::SinOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::SinOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::FloorFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::FloorFOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
if (args[0].getType().isa<FloatType>()) {
|
||||
auto pos_inf = APFloat::getInf(
|
||||
args[0].getType().cast<FloatType>().getFloatSemantics());
|
||||
|
@ -437,8 +497,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
|
|||
template <typename... Args>
|
||||
struct CompareSelectOpToStdScalarOp {
|
||||
static Value map(Location loc, StringRef comparison_direction,
|
||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
@ -450,28 +510,30 @@ template <typename SupportedType, typename StdCompareOp, typename Predicate,
|
|||
struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
|
||||
Args...> {
|
||||
static Value map(Location loc, StringRef comparison_direction,
|
||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
Type element_type = getElementTypeOrSelf(arg_types.front());
|
||||
if (element_type.isa<SupportedType>()) {
|
||||
auto predicate = getCmpPredicate<Predicate>(comparison_direction);
|
||||
auto predicate = getCmpPredicate<Predicate>(
|
||||
comparison_direction, !element_type.isUnsignedInteger());
|
||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||
auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(),
|
||||
args[0], args[1]);
|
||||
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
|
||||
}
|
||||
return CompareSelectOpToStdScalarOp<Args...>::map(loc, comparison_direction,
|
||||
result_types, args, b);
|
||||
return CompareSelectOpToStdScalarOp<Args...>::map(
|
||||
loc, comparison_direction, result_types, arg_types, args, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::LogOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::LogOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc,
|
||||
|
@ -493,8 +555,8 @@ inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc,
|
|||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
auto ty = result_types.front().cast<FloatType>();
|
||||
Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0));
|
||||
Value x = args.front();
|
||||
|
@ -507,43 +569,47 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
|
|||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::Log1pOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::Log1pOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return LhloAlwaysPropagateNaN(
|
||||
CompareSelectOpToStdScalarOp<
|
||||
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
|
||||
result_types, args,
|
||||
b),
|
||||
result_types,
|
||||
arg_types, args, b),
|
||||
args, loc, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return LhloAlwaysPropagateNaN(
|
||||
CompareSelectOpToStdScalarOp<
|
||||
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
|
||||
result_types, args,
|
||||
b),
|
||||
result_types,
|
||||
arg_types, args, b),
|
||||
args, loc, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
assert(args.size() == 3 && "expected 3 arguments");
|
||||
|
@ -552,21 +618,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
|
|||
Value ub = args[2];
|
||||
|
||||
// clamp(lb, x, ub) = max(min(x, ub), lb)
|
||||
Value min_x_ub =
|
||||
MapLhloOpToStdScalarOp<lmhlo::MinOp>(loc, result_types, {x, ub}, b);
|
||||
return MapLhloOpToStdScalarOp<lmhlo::MaxOp>(loc, result_types, {min_x_ub, lb},
|
||||
b);
|
||||
Value min_x_ub = MapLhloOpToStdScalarOp<lmhlo::MinOp>(loc, result_types,
|
||||
arg_types, {x, ub}, b);
|
||||
return MapLhloOpToStdScalarOp<lmhlo::MaxOp>(loc, result_types, arg_types,
|
||||
{min_x_ub, lb}, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
if (element_type.isa<FloatType>()) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
// lmhlo.neg(x, result) -> result = sub(0, x)
|
||||
|
@ -586,6 +653,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
|||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
|
@ -604,24 +672,27 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
|||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::OrOp>{}(loc, result_types,
|
||||
args, b);
|
||||
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::OrOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::RsqrtOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::RsqrtOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
lmhlo::PowOp::Adaptor adaptor(args);
|
||||
|
@ -630,7 +701,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
|
|||
auto result_type = result_types.front();
|
||||
if (result_type.isa<::mlir::FloatType>())
|
||||
return MapLhloOpToScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types,
|
||||
args, b);
|
||||
arg_types, args, b);
|
||||
|
||||
assert(result_type.isa<::mlir::IntegerType>() &&
|
||||
"only float and integer `pow` is supported right now");
|
||||
|
@ -699,39 +770,41 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
|
|||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
|
||||
b);
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<::mlir::SelectOp>{}(loc, result_types,
|
||||
arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
|
||||
loc, result_types, args, b);
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<isAnyIntegerType, mlir::ShiftLeftOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
|
||||
loc, result_types, args, b);
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<isAnyIntegerType, mlir::SignedShiftRightOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
|
||||
loc, result_types, args, b);
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<isAnyIntegerType,
|
||||
mlir::UnsignedShiftRightOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
|
@ -780,39 +853,43 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
|||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::SqrtOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::SqrtOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
|
||||
FloatType, ScalarFOp<lmhlo::SubOp>,
|
||||
ComplexType, ScalarCOp<lmhlo::SubOp>>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ScalarIOp<lmhlo::SubOp>,
|
||||
isFloatType, ScalarFOp<lmhlo::SubOp>,
|
||||
isComplexType, ScalarCOp<lmhlo::SubOp>>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::TanhOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::TanhOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
|
||||
loc, result_types, args, b);
|
||||
return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::XOrOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
@ -826,7 +903,8 @@ struct HloOpToStdScalarOp {
|
|||
std::false_type>::value>>
|
||||
static Value map(HloOpTy op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
|
||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
|
||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(
|
||||
op.getLoc(), result_types, llvm::to_vector<4>(op->getOperandTypes()),
|
||||
args, b);
|
||||
}
|
||||
|
||||
|
@ -837,7 +915,8 @@ struct HloOpToStdScalarOp {
|
|||
!std::is_same<LhloOpTy, std::false_type>::value>>
|
||||
static Value map(HloOpTy op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
|
||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
|
||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(
|
||||
op.getLoc(), result_types, llvm::to_vector<4>(op->getOperandTypes()),
|
||||
args, b);
|
||||
}
|
||||
|
||||
|
@ -848,7 +927,8 @@ struct HloOpToStdScalarOp {
|
|||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
auto comparison_direction = op.comparison_direction();
|
||||
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
||||
op.getLoc(), comparison_direction, result_types, args, b);
|
||||
op.getLoc(), comparison_direction, result_types,
|
||||
llvm::to_vector<4>(op->getOperandTypes()), args, b);
|
||||
}
|
||||
|
||||
// Implementation for mhlo::CompareOp.
|
||||
|
@ -859,7 +939,8 @@ struct HloOpToStdScalarOp {
|
|||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
auto comparison_direction = op.comparison_direction();
|
||||
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
||||
op.getLoc(), comparison_direction, result_types, args, b);
|
||||
op.getLoc(), comparison_direction, result_types,
|
||||
llvm::to_vector<4>(op->getOperandTypes()), args, b);
|
||||
}
|
||||
|
||||
// Implementation for LHLO ops except lmhlo::CompareOp.
|
||||
|
@ -869,18 +950,20 @@ struct HloOpToStdScalarOp {
|
|||
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
|
||||
std::false_type>::value>>
|
||||
static Value map(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
|
||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, args, b);
|
||||
ArrayRef<Type> arg_types, ArrayRef<Value> args, OpBuilder* b,
|
||||
unsigned i = 0) {
|
||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, arg_types,
|
||||
args, b);
|
||||
}
|
||||
|
||||
// Implementation for lmhlo::CompareOp.
|
||||
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
|
||||
LhloOpTy, lmhlo::CompareOp>::value>>
|
||||
static Value map(Location loc, StringRef comparison_direction,
|
||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
|
||||
loc, comparison_direction, result_types, args, b);
|
||||
loc, comparison_direction, result_types, arg_types, args, b);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -57,8 +57,12 @@ void populateHLOToLHLOConversionPattern(MLIRContext *context,
|
|||
|
||||
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
|
||||
void populateHLOToLinalgConversionPattern(MLIRContext *context,
|
||||
TypeConverter &typeConverter,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
// Converter to signless intergers to be used with linalg conversion patterns.
|
||||
std::unique_ptr<TypeConverter> createHloToLinalgSignedIntegerConverter();
|
||||
|
||||
// Sets up legality definitions for materializing broadcasts.
|
||||
void SetupMaterializeBroadcastsLegality(MLIRContext *context,
|
||||
ConversionTarget *conversionTarget);
|
||||
|
|
|
@ -241,16 +241,15 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
|||
!(t.getElementType().isSignlessIntOrFloat() ||
|
||||
t.getElementType().isa<ComplexType>());
|
||||
};
|
||||
if (llvm::any_of(args,
|
||||
[&](Value v) {
|
||||
return fail(v.getType().dyn_cast<ShapedType>());
|
||||
}) ||
|
||||
llvm::any_of(op.getOperation()->getResultTypes(),
|
||||
[&](Type t) { return fail(t.dyn_cast<ShapedType>()); }))
|
||||
if (llvm::any_of(op.getOperation()->getResultTypes(), [&](Type t) {
|
||||
return fail(this->typeConverter->convertType(t)
|
||||
.template dyn_cast<ShapedType>());
|
||||
})) {
|
||||
return emitError(loc,
|
||||
"lhlo to linalg conversion expects ranked args of "
|
||||
"hlo to linalg conversion expects ranked args of "
|
||||
"signless int, float or complex element type with ")
|
||||
<< nloops << " parallel iterators: " << *(op.getOperation());
|
||||
}
|
||||
|
||||
// Construct the indexing maps needed for linalg.generic ops.
|
||||
SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
|
||||
|
@ -270,12 +269,12 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
|||
if (isLHLO) {
|
||||
output_buffers.append(args.begin() + num_inputs, args.end());
|
||||
} else {
|
||||
Value result = op.getOperation()->getResult(0);
|
||||
ShapedType result_type = result.getType().template cast<ShapedType>();
|
||||
Type result_type = this->typeConverter->convertType(
|
||||
op.getOperation()->getResult(0).getType());
|
||||
auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
|
||||
output_buffers.push_back(
|
||||
GetInitTensor(rewriter, loc, result_type, dyn_sizes));
|
||||
op_result_types.push_back(result.getType());
|
||||
output_buffers.push_back(GetInitTensor(
|
||||
rewriter, loc, result_type.cast<ShapedType>(), dyn_sizes));
|
||||
op_result_types.push_back(result_type);
|
||||
}
|
||||
body_result_types = llvm::to_vector<4>(llvm::map_range(
|
||||
output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));
|
||||
|
@ -1279,8 +1278,10 @@ class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
|
|||
// map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it
|
||||
// to an lmhlo op and call the lmhlo implementation.
|
||||
start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
|
||||
loc, start_index.getType(), ArrayRef<Value>{zero, start_index, ub},
|
||||
&rewriter);
|
||||
loc, start_index.getType(),
|
||||
ArrayRef<Type>{start_index.getType(), start_index.getType(),
|
||||
start_index.getType()},
|
||||
ArrayRef<Value>{zero, start_index, ub}, &rewriter);
|
||||
start_indices.push_back(
|
||||
rewriter.create<IndexCastOp>(loc, index_type, start_index)
|
||||
.getResult());
|
||||
|
@ -2073,6 +2074,7 @@ struct TorchIndexSelectOpOnTensorsConversion
|
|||
};
|
||||
|
||||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
TypeConverter& typeConverter,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
||||
|
@ -2128,10 +2130,65 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
|
||||
SliceConverter<lmhlo::SliceOp>,
|
||||
TransposeConverter<lmhlo::TransposeOp>
|
||||
>(context);
|
||||
>(typeConverter, context);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// Converter that turns signed/unsigned integers types into signless types.
|
||||
class RemoveSignTypeConverter : public TypeConverter {
|
||||
public:
|
||||
RemoveSignTypeConverter() {
|
||||
addConversion([](Type type) { return type; });
|
||||
|
||||
addConversion(convertInteger);
|
||||
addConversion(convertShapedType);
|
||||
|
||||
addArgumentMaterialization(materializeCastFromIllegal);
|
||||
addSourceMaterialization(materializeCastToIllegal);
|
||||
addTargetMaterialization(materializeCastFromIllegal);
|
||||
}
|
||||
|
||||
private:
|
||||
static Type convertInteger(IntegerType int_type) {
|
||||
return IntegerType::get(int_type.getContext(),
|
||||
int_type.getIntOrFloatBitWidth());
|
||||
}
|
||||
|
||||
static Type convertShapedType(ShapedType shaped_type) {
|
||||
if (auto int_type = shaped_type.getElementType().dyn_cast<IntegerType>())
|
||||
return shaped_type.clone(convertInteger(int_type));
|
||||
return shaped_type;
|
||||
}
|
||||
|
||||
static llvm::Optional<Value> materializeCastFromIllegal(OpBuilder& builder,
|
||||
Type type,
|
||||
ValueRange inputs,
|
||||
Location loc) {
|
||||
Type from_type = getElementTypeOrSelf(inputs[0].getType());
|
||||
Type to_type = getElementTypeOrSelf(type);
|
||||
if ((!from_type.isSignedInteger() && !from_type.isUnsignedInteger()) ||
|
||||
!to_type.isSignlessInteger())
|
||||
return llvm::None;
|
||||
// Use unrealized conversion casts to do signful->signless conversions.
|
||||
return builder.create<UnrealizedConversionCastOp>(loc, type, inputs[0])
|
||||
->getResult(0);
|
||||
}
|
||||
|
||||
static llvm::Optional<Value> materializeCastToIllegal(OpBuilder& builder,
|
||||
Type type,
|
||||
ValueRange inputs,
|
||||
Location loc) {
|
||||
Type from_type = getElementTypeOrSelf(inputs[0].getType());
|
||||
Type to_type = getElementTypeOrSelf(type);
|
||||
if (!from_type.isSignlessInteger() ||
|
||||
(!to_type.isSignedInteger() && !to_type.isUnsignedInteger()))
|
||||
return llvm::None;
|
||||
// Use unrealized conversion casts to do signless->signful conversions.
|
||||
return builder.create<UnrealizedConversionCastOp>(loc, type, inputs[0])
|
||||
->getResult(0);
|
||||
}
|
||||
};
|
||||
|
||||
// Converts LHLO ops to Linalg generic.
|
||||
// Sample result for lmhlo::AddOp.
|
||||
//
|
||||
|
@ -2163,9 +2220,12 @@ struct LhloLegalizeToLinalgPass
|
|||
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
||||
math::MathDialect, memref::MemRefDialect,
|
||||
StandardOpsDialect, AffineDialect>();
|
||||
target.addLegalOp<UnrealizedConversionCastOp>();
|
||||
|
||||
RemoveSignTypeConverter type_converter;
|
||||
auto func = getFunction();
|
||||
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
populateLHLOToLinalgConversionPattern(func.getContext(), type_converter,
|
||||
&patterns);
|
||||
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
|
@ -2181,17 +2241,20 @@ struct HloLegalizeToLinalgPass
|
|||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns(&getContext());
|
||||
ConversionTarget target(getContext());
|
||||
MLIRContext& ctx = getContext();
|
||||
OwningRewritePatternList patterns(&ctx);
|
||||
ConversionTarget target(ctx);
|
||||
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
||||
math::MathDialect, StandardOpsDialect,
|
||||
tensor::TensorDialect, scf::SCFDialect>();
|
||||
|
||||
// TODO: DimOp shouldn't be in MemRefDialect
|
||||
target.addLegalOp<memref::DimOp>();
|
||||
target.addLegalOp<UnrealizedConversionCastOp>();
|
||||
|
||||
RemoveSignTypeConverter type_converter;
|
||||
auto func = getFunction();
|
||||
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
mhlo::populateHLOToLinalgConversionPattern(&ctx, type_converter, &patterns);
|
||||
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
|
@ -2209,6 +2272,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
|
|||
namespace mhlo {
|
||||
|
||||
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
TypeConverter& type_converter,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
|
@ -2272,7 +2336,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
ReduceOnTensorsConversion,
|
||||
ReduceWindowOpOnTensorsConversion,
|
||||
TorchIndexSelectOpOnTensorsConversion,
|
||||
PadOpOnTensorsConversion>(context);
|
||||
PadOpOnTensorsConversion>(type_converter, context);
|
||||
// clang-format on
|
||||
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
|
||||
ReduceRegionXLAOpConversion<mhlo::MinOp>,
|
||||
|
@ -2287,5 +2351,10 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||
return std::make_unique<HloLegalizeToLinalgPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<TypeConverter> createHloToLinalgSignedIntegerConverter() {
|
||||
return std::make_unique<RemoveSignTypeConverter>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -603,6 +603,21 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @maxu
|
||||
func @maxu(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
|
||||
%0 = "mhlo.maximum"(%lhs, %rhs)
|
||||
: (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
|
||||
return %0 : tensor<2x2xui32>
|
||||
}
|
||||
// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xi32>
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32):
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpi ugt, %[[LHS_IN]], %[[RHS_IN]] : i32
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()>
|
||||
// CHECK-LABEL: func @add_scalar
|
||||
func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
|
||||
|
@ -2196,3 +2211,43 @@ func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>)
|
|||
} : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
return %concat : tensor<?x?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: unsigned_divide
|
||||
func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: divi_unsigned
|
||||
%0 = "mhlo.divide"(%lhs, %rhs) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
|
||||
return %0 : tensor<2x2xui32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: unsigned_remainder
|
||||
func @unsigned_remainder(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: remi_unsigned
|
||||
%0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
|
||||
return %0 : tensor<2x2xui32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: unsigned_convert
|
||||
func @unsigned_convert(%in: tensor<2x2xui32>) -> tensor<2x2xui64> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: zexti
|
||||
%0 = "mhlo.convert"(%in) : (tensor<2x2xui32>) -> tensor<2x2xui64>
|
||||
return %0 : tensor<2x2xui64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: unsigned_compare
|
||||
func @unsigned_compare(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xi1> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: cmpi ugt
|
||||
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xi1>
|
||||
return %0 : tensor<2x2xi1>
|
||||
}
|
||||
|
|
|
@ -93,6 +93,21 @@ func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @maxu
|
||||
func @maxu(%lhs: memref<2x2xui32>, %rhs: memref<2x2xui32>,
|
||||
%result: memref<2x2xui32>) {
|
||||
"lmhlo.maximum"(%lhs, %rhs, %result)
|
||||
: (memref<2x2xui32>, memref<2x2xui32>, memref<2x2xui32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i32):
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpi ugt, %[[LHS_IN]], %[[RHS_IN]] : i32
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @and
|
||||
func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||
%result: memref<2x2xi32>) {
|
||||
|
|
Loading…
Reference in New Issue