Lower LHLO::AbsOp to complex dialect.
Also fix the traits for LHLO::AbsOp to allow different types and add a verifier. PiperOrigin-RevId: 370438790
This commit is contained in:
parent
1c11075d62
commit
0e2b255f01
|
@ -77,7 +77,10 @@ class LHLO_UnaryElementwiseOp<string mnemonic,
|
||||||
Arg<BufferType, "", [MemWrite]>:$output);
|
Arg<BufferType, "", [MemWrite]>:$output);
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
|
// Abs supports complex to real, so element type is not guaranteed to match.
|
||||||
|
def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_AbsOp {
|
||||||
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(timshen): add a custom verifier.
|
// TODO(timshen): add a custom verifier.
|
||||||
def LHLO_BitcastConvertOp:
|
def LHLO_BitcastConvertOp:
|
||||||
|
|
|
@ -84,7 +84,7 @@ template <typename LhloOp>
|
||||||
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
|
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
|
||||||
|
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
struct MapLhloOpToStdScalarOpImpl {
|
struct MapLhloOpToScalarOpImpl {
|
||||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -92,7 +92,7 @@ struct MapLhloOpToStdScalarOpImpl {
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename StdScalarOp>
|
template <typename StdScalarOp>
|
||||||
struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
|
struct MapLhloOpToScalarOpImpl<StdScalarOp> {
|
||||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
|
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
|
||||||
|
@ -100,7 +100,7 @@ struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename SupportedType, typename StdScalarOp, typename... Args>
|
template <typename SupportedType, typename StdScalarOp, typename... Args>
|
||||||
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
struct MapLhloOpToScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||||
|
@ -108,7 +108,7 @@ struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||||
return b->template create<StdScalarOp>(loc, result_types, args,
|
return b->template create<StdScalarOp>(loc, result_types, args,
|
||||||
mlir::None);
|
mlir::None);
|
||||||
}
|
}
|
||||||
return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
|
return MapLhloOpToScalarOpImpl<Args...>{}(loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -117,9 +117,9 @@ struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||||
template <typename LhloOpTy>
|
template <typename LhloOpTy>
|
||||||
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
|
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
|
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
|
||||||
ScalarFOp<LhloOpTy>>{}(loc, result_types,
|
ScalarFOp<LhloOpTy>>{}(loc, result_types, args,
|
||||||
args, b);
|
b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -129,7 +129,11 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||||
if (element_type.isa<FloatType>()) {
|
if (element_type.isa<FloatType>()) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
if (element_type.isa<ComplexType>()) {
|
||||||
|
return MapLhloOpToScalarOpImpl<ComplexType, ::mlir::complex::AbsOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
|
@ -154,7 +158,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
|
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
|
||||||
FloatType, ScalarFOp<lmhlo::AddOp>,
|
FloatType, ScalarFOp<lmhlo::AddOp>,
|
||||||
ComplexType, ScalarCOp<lmhlo::AddOp>>{}(
|
ComplexType, ScalarCOp<lmhlo::AddOp>>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
|
@ -165,7 +169,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,7 +178,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Atan2Op>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::Atan2Op>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,7 +251,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::ExpOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,7 +260,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -265,7 +269,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -273,8 +277,8 @@ template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<complex::CreateOp>{}(loc, result_types,
|
return MapLhloOpToScalarOpImpl<complex::CreateOp>{}(loc, result_types, args,
|
||||||
args, b);
|
b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -282,8 +286,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<complex::ReOp>{}(loc, result_types, args,
|
return MapLhloOpToScalarOpImpl<complex::ReOp>{}(loc, result_types, args, b);
|
||||||
b);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -291,8 +294,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<complex::ImOp>{}(loc, result_types, args,
|
return MapLhloOpToScalarOpImpl<complex::ImOp>{}(loc, result_types, args, b);
|
||||||
b);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -373,15 +375,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
|
||||||
const auto& result = args[2];
|
const auto& result = args[2];
|
||||||
Type element_type = lhs.getType();
|
Type element_type = lhs.getType();
|
||||||
if (element_type.isa<FloatType>()) {
|
if (element_type.isa<FloatType>()) {
|
||||||
Value float_mul = MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::MulFOp>{}(
|
Value float_mul = MapLhloOpToScalarOpImpl<FloatType, ::mlir::MulFOp>{}(
|
||||||
loc, result_types, {lhs, rhs}, b);
|
loc, result_types, {lhs, rhs}, b);
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AddFOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::AddFOp>{}(
|
||||||
loc, result_types, {float_mul, result}, b);
|
loc, result_types, {float_mul, result}, b);
|
||||||
}
|
}
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
Value int_mul = MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::MulIOp>{}(
|
Value int_mul = MapLhloOpToScalarOpImpl<IntegerType, ::mlir::MulIOp>{}(
|
||||||
loc, result_types, {lhs, rhs}, b);
|
loc, result_types, {lhs, rhs}, b);
|
||||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AddIOp>{}(
|
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::AddIOp>{}(
|
||||||
loc, result_types, {int_mul, result}, b);
|
loc, result_types, {int_mul, result}, b);
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -392,7 +394,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::CosOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::CosOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -401,7 +403,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SinOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::SinOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -410,7 +412,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::FloorFOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::FloorFOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -468,7 +470,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::LogOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::LogOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -507,7 +509,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Log1pOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::Log1pOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -563,7 +565,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||||
if (element_type.isa<FloatType>()) {
|
if (element_type.isa<FloatType>()) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
|
@ -604,8 +606,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}(
|
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::OrOp>{}(loc, result_types,
|
||||||
loc, result_types, args, b);
|
args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -613,7 +615,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::RsqrtOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::RsqrtOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -627,7 +629,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
|
||||||
// Floating point can use std::powf
|
// Floating point can use std::powf
|
||||||
auto result_type = result_types.front();
|
auto result_type = result_types.front();
|
||||||
if (result_type.isa<::mlir::FloatType>())
|
if (result_type.isa<::mlir::FloatType>())
|
||||||
return MapLhloOpToStdScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types,
|
return MapLhloOpToScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types,
|
||||||
args, b);
|
args, b);
|
||||||
|
|
||||||
assert(result_type.isa<::mlir::IntegerType>() &&
|
assert(result_type.isa<::mlir::IntegerType>() &&
|
||||||
|
@ -699,7 +701,7 @@ template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
|
return MapLhloOpToScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
|
||||||
b);
|
b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -707,7 +709,7 @@ template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
|
return MapLhloOpToScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -715,7 +717,7 @@ template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
|
return MapLhloOpToScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -723,7 +725,7 @@ template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
|
return MapLhloOpToScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -780,7 +782,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SqrtOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::SqrtOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -789,7 +791,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
|
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
|
||||||
FloatType, ScalarFOp<lmhlo::SubOp>,
|
FloatType, ScalarFOp<lmhlo::SubOp>,
|
||||||
ComplexType, ScalarCOp<lmhlo::SubOp>>{}(
|
ComplexType, ScalarCOp<lmhlo::SubOp>>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
|
@ -800,7 +802,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::TanhOp>{}(
|
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::TanhOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -809,7 +811,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
|
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -62,6 +62,30 @@ LmhloDialect::LmhloDialect(MLIRContext* context)
|
||||||
>();
|
>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AbsOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static LogicalResult Verify(AbsOp op) {
|
||||||
|
auto operand_type = getElementTypeOrSelf(op.input().getType());
|
||||||
|
auto output_type = getElementTypeOrSelf(op.output().getType());
|
||||||
|
if (auto complex_type = operand_type.dyn_cast<ComplexType>()) {
|
||||||
|
if (complex_type.getElementType() != output_type) {
|
||||||
|
return op.emitOpError(
|
||||||
|
"requires output type to be the same as the element type of the "
|
||||||
|
"input");
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (operand_type != output_type)
|
||||||
|
return op.emitOpError("requires all operands to have the same type");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AllToAllOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Verifies replica groups attached to collective communication operations.
|
// Verifies replica groups attached to collective communication operations.
|
||||||
// If the attribute is not empty, it must be a rank 2 tensor, and each replica
|
// If the attribute is not empty, it must be a rank 2 tensor, and each replica
|
||||||
// should appear exactly once. If `is_uniform_sized` is true, then we also check
|
// should appear exactly once. If `is_uniform_sized` is true, then we also check
|
||||||
|
@ -120,8 +144,8 @@ static LogicalResult Verify(AllReduceOp op) {
|
||||||
if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
|
if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// AllReduce had variadic operands and results that have the same size.
|
// AllReduce has variadic operands and results that have the same size.
|
||||||
// Each memeber of the operand should have the same type as the corresponding
|
// Each member of the operand should have the same type as the corresponding
|
||||||
// member of the result.
|
// member of the result.
|
||||||
for (auto it : llvm::enumerate(
|
for (auto it : llvm::enumerate(
|
||||||
llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
|
llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
|
||||||
|
|
|
@ -378,6 +378,20 @@ func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @complex_abs
|
||||||
|
func @complex_abs(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||||
|
"lmhlo.abs"(%input, %result)
|
||||||
|
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[ABS_OUT:.*]]: f32):
|
||||||
|
// CHECK-NEXT: %[[ABS:.*]] = complex.abs %[[CPLX_IN:.*]] : complex<f32>
|
||||||
|
// CHECK-NEXT: linalg.yield %[[ABS]] : f32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @absi
|
// CHECK-LABEL: func @absi
|
||||||
func @absi(%input: memref<2x2xi32>,
|
func @absi(%input: memref<2x2xi32>,
|
||||||
%result: memref<2x2xi32>) {
|
%result: memref<2x2xi32>) {
|
||||||
|
|
|
@ -1151,3 +1151,20 @@ func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
|
||||||
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @invalid_complex_abs_call(%input:memref<2xcomplex<f32>>, %result:memref<2xcomplex<f32>>) -> () {
|
||||||
|
// expected-error @+1 {{requires output type to be the same as the element type of the input}}
|
||||||
|
"lmhlo.abs"(%input, %result)
|
||||||
|
: (memref<2xcomplex<f32>>, memref<2xcomplex<f32>>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @invalid_float_abs_call(%input:memref<2xf32>, %result:memref<2xf64>) -> () {
|
||||||
|
// expected-error @+1 {{requires all operands to have the same type}}
|
||||||
|
"lmhlo.abs"(%input, %result) : (memref<2xf32>, memref<2xf64>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue