Lower LHLO::AbsOp to complex dialect.

Also fix the traits for LHLO::AbsOp to allow different types and add a
verifier.

PiperOrigin-RevId: 370438790
This commit is contained in:
Adrian Kuegel 2021-04-26 05:42:39 -07:00 committed by TensorFlow MLIR Team
parent 1c11075d62
commit 0e2b255f01
5 changed files with 111 additions and 51 deletions

View File

@ -77,7 +77,10 @@ class LHLO_UnaryElementwiseOp<string mnemonic,
Arg<BufferType, "", [MemWrite]>:$output);
}
def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
// Abs supports complex to real, so element type is not guaranteed to match.
def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_AbsOp {
let verifier = [{ return Verify(*this); }];
}
// TODO(timshen): add a custom verifier.
def LHLO_BitcastConvertOp:

View File

@ -84,7 +84,7 @@ template <typename LhloOp>
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
template <typename... Args>
struct MapLhloOpToStdScalarOpImpl {
struct MapLhloOpToScalarOpImpl {
Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return nullptr;
@ -92,7 +92,7 @@ struct MapLhloOpToStdScalarOpImpl {
};
template <typename StdScalarOp>
struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
struct MapLhloOpToScalarOpImpl<StdScalarOp> {
Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
@ -100,7 +100,7 @@ struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
};
template <typename SupportedType, typename StdScalarOp, typename... Args>
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
struct MapLhloOpToScalarOpImpl<SupportedType, StdScalarOp, Args...> {
Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
Type element_type = getElementTypeOrSelf(args.front().getType());
@ -108,7 +108,7 @@ struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
return b->template create<StdScalarOp>(loc, result_types, args,
mlir::None);
}
return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
return MapLhloOpToScalarOpImpl<Args...>{}(loc, result_types, args, b);
}
};
@ -117,9 +117,9 @@ struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
template <typename LhloOpTy>
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
ScalarFOp<LhloOpTy>>{}(loc, result_types,
args, b);
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
ScalarFOp<LhloOpTy>>{}(loc, result_types, args,
b);
}
template <>
@ -129,7 +129,11 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
OpBuilder* b) {
Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
loc, result_types, args, b);
}
if (element_type.isa<ComplexType>()) {
return MapLhloOpToScalarOpImpl<ComplexType, ::mlir::complex::AbsOp>{}(
loc, result_types, args, b);
}
if (element_type.isa<IntegerType>()) {
@ -154,7 +158,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
FloatType, ScalarFOp<lmhlo::AddOp>,
ComplexType, ScalarCOp<lmhlo::AddOp>>{}(
loc, result_types, args, b);
@ -165,7 +169,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
loc, result_types, args, b);
}
@ -174,7 +178,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Atan2Op>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::Atan2Op>{}(
loc, result_types, args, b);
}
@ -247,7 +251,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::ExpOp>{}(
loc, result_types, args, b);
}
@ -256,7 +260,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}(
loc, result_types, args, b);
}
@ -265,7 +269,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
loc, result_types, args, b);
}
@ -273,8 +277,8 @@ template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<complex::CreateOp>{}(loc, result_types,
args, b);
return MapLhloOpToScalarOpImpl<complex::CreateOp>{}(loc, result_types, args,
b);
}
template <>
@ -282,8 +286,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<complex::ReOp>{}(loc, result_types, args,
b);
return MapLhloOpToScalarOpImpl<complex::ReOp>{}(loc, result_types, args, b);
}
template <>
@ -291,8 +294,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<complex::ImOp>{}(loc, result_types, args,
b);
return MapLhloOpToScalarOpImpl<complex::ImOp>{}(loc, result_types, args, b);
}
template <>
@ -373,15 +375,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
const auto& result = args[2];
Type element_type = lhs.getType();
if (element_type.isa<FloatType>()) {
Value float_mul = MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::MulFOp>{}(
Value float_mul = MapLhloOpToScalarOpImpl<FloatType, ::mlir::MulFOp>{}(
loc, result_types, {lhs, rhs}, b);
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AddFOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::AddFOp>{}(
loc, result_types, {float_mul, result}, b);
}
if (element_type.isa<IntegerType>()) {
Value int_mul = MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::MulIOp>{}(
Value int_mul = MapLhloOpToScalarOpImpl<IntegerType, ::mlir::MulIOp>{}(
loc, result_types, {lhs, rhs}, b);
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AddIOp>{}(
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::AddIOp>{}(
loc, result_types, {int_mul, result}, b);
}
return nullptr;
@ -392,7 +394,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::CosOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::CosOp>{}(
loc, result_types, args, b);
}
@ -401,7 +403,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SinOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::SinOp>{}(
loc, result_types, args, b);
}
@ -410,7 +412,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::FloorFOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::FloorFOp>{}(
loc, result_types, args, b);
}
@ -468,7 +470,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::LogOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::LogOp>{}(
loc, result_types, args, b);
}
@ -507,7 +509,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Log1pOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::Log1pOp>{}(
loc, result_types, args, b);
}
@ -563,7 +565,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
OpBuilder* b) {
Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
loc, result_types, args, b);
}
if (element_type.isa<IntegerType>()) {
@ -604,8 +606,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}(
loc, result_types, args, b);
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::OrOp>{}(loc, result_types,
args, b);
}
template <>
@ -613,7 +615,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::RsqrtOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::RsqrtOp>{}(
loc, result_types, args, b);
}
@ -627,7 +629,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
// Floating point can use std::powf
auto result_type = result_types.front();
if (result_type.isa<::mlir::FloatType>())
return MapLhloOpToStdScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types,
return MapLhloOpToScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types,
args, b);
assert(result_type.isa<::mlir::IntegerType>() &&
@ -699,7 +701,7 @@ template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
return MapLhloOpToScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
b);
}
@ -707,7 +709,7 @@ template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
return MapLhloOpToScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}(
loc, result_types, args, b);
}
@ -715,7 +717,7 @@ template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
return MapLhloOpToScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}(
loc, result_types, args, b);
}
@ -723,7 +725,7 @@ template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
return MapLhloOpToScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}(
loc, result_types, args, b);
}
@ -780,7 +782,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SqrtOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::SqrtOp>{}(
loc, result_types, args, b);
}
@ -789,7 +791,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
return MapLhloOpToScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
FloatType, ScalarFOp<lmhlo::SubOp>,
ComplexType, ScalarCOp<lmhlo::SubOp>>{}(
loc, result_types, args, b);
@ -800,7 +802,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::TanhOp>{}(
return MapLhloOpToScalarOpImpl<FloatType, ::mlir::math::TanhOp>{}(
loc, result_types, args, b);
}
@ -809,7 +811,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
return MapLhloOpToScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
loc, result_types, args, b);
}

View File

@ -62,6 +62,30 @@ LmhloDialect::LmhloDialect(MLIRContext* context)
>();
}
//===----------------------------------------------------------------------===//
// AbsOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(AbsOp op) {
auto operand_type = getElementTypeOrSelf(op.input().getType());
auto output_type = getElementTypeOrSelf(op.output().getType());
if (auto complex_type = operand_type.dyn_cast<ComplexType>()) {
if (complex_type.getElementType() != output_type) {
return op.emitOpError(
"requires output type to be the same as the element type of the "
"input");
}
return success();
}
if (operand_type != output_type)
return op.emitOpError("requires all operands to have the same type");
return success();
}
//===----------------------------------------------------------------------===//
// AllToAllOp
//===----------------------------------------------------------------------===//
// Verifies replica groups attached to collective communication operations.
// If the attribute is not empty, it must be a rank 2 tensor, and each replica
// should appear exactly once. If `is_uniform_sized` is true, then we also check
@ -120,8 +144,8 @@ static LogicalResult Verify(AllReduceOp op) {
if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
return failure();
// AllReduce had variadic operands and results that have the same size.
// Each memeber of the operand should have the same type as the corresponding
// AllReduce has variadic operands and results that have the same size.
// Each member of the operand should have the same type as the corresponding
// member of the result.
for (auto it : llvm::enumerate(
llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {

View File

@ -378,6 +378,20 @@ func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// -----
// CHECK-LABEL: func @complex_abs
func @complex_abs(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
"lmhlo.abs"(%input, %result)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[ABS_OUT:.*]]: f32):
// CHECK-NEXT: %[[ABS:.*]] = complex.abs %[[CPLX_IN:.*]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[ABS]] : f32
// -----
// CHECK-LABEL: func @absi
func @absi(%input: memref<2x2xi32>,
%result: memref<2x2xi32>) {

View File

@ -1151,3 +1151,20 @@ func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
func @invalid_complex_abs_call(%input:memref<2xcomplex<f32>>, %result:memref<2xcomplex<f32>>) -> () {
// expected-error @+1 {{requires output type to be the same as the element type of the input}}
"lmhlo.abs"(%input, %result)
: (memref<2xcomplex<f32>>, memref<2xcomplex<f32>>) -> ()
return
}
// -----
func @invalid_float_abs_call(%input:memref<2xf32>, %result:memref<2xf64>) -> () {
// expected-error @+1 {{requires all operands to have the same type}}
"lmhlo.abs"(%input, %result) : (memref<2xf32>, memref<2xf64>) -> ()
return
}