From 0e2b255f01a18ba1092e46f4fa47dacf478d1c5d Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 26 Apr 2021 05:42:39 -0700 Subject: [PATCH] Lower LHLO::AbsOp to complex dialect. Also fix the traits for LHLO::AbsOp to allow different types and add a verifier. PiperOrigin-RevId: 370438790 --- include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 5 +- .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 98 ++++++++++--------- lib/Dialect/mhlo/IR/lhlo_ops.cc | 28 +++++- tests/lhlo-legalize-to-linalg.mlir | 14 +++ tests/lhlo_ops.mlir | 17 ++++ 5 files changed, 111 insertions(+), 51 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index e7c7819..eb8c4d1 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -77,7 +77,10 @@ class LHLO_UnaryElementwiseOp:$output); } -def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp; +// Abs supports complex to real, so element type is not guaranteed to match. +def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_AbsOp { + let verifier = [{ return Verify(*this); }]; +} // TODO(timshen): add a custom verifier. def LHLO_BitcastConvertOp: diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 5a957e7..d6bbc17 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -84,7 +84,7 @@ template using ScalarCOp = typename LhloToScalarOp::COp; template -struct MapLhloOpToStdScalarOpImpl { +struct MapLhloOpToScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return nullptr; @@ -92,7 +92,7 @@ struct MapLhloOpToStdScalarOpImpl { }; template -struct MapLhloOpToStdScalarOpImpl { +struct MapLhloOpToScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return b->template create(loc, result_types, args, mlir::None); @@ -100,7 +100,7 @@ struct MapLhloOpToStdScalarOpImpl { }; template -struct MapLhloOpToStdScalarOpImpl { +struct MapLhloOpToScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { Type element_type = getElementTypeOrSelf(args.front().getType()); @@ -108,7 +108,7 @@ struct MapLhloOpToStdScalarOpImpl { return b->template create(loc, result_types, args, mlir::None); } - return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}(loc, result_types, args, b); } }; @@ -117,9 +117,9 @@ struct MapLhloOpToStdScalarOpImpl { template inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl, FloatType, - ScalarFOp>{}(loc, result_types, - args, b); + return MapLhloOpToScalarOpImpl, FloatType, + ScalarFOp>{}(loc, result_types, args, + b); } template <> @@ -129,7 +129,11 @@ inline Value MapLhloOpToStdScalarOp(Location loc, OpBuilder* b) { Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( + loc, result_types, args, b); + } + if (element_type.isa()) { + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } if (element_type.isa()) { @@ -154,9 +158,9 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl, - FloatType, ScalarFOp, - ComplexType, ScalarCOp>{}( + return MapLhloOpToScalarOpImpl, + FloatType, ScalarFOp, + ComplexType, ScalarCOp>{}( loc, result_types, args, b); } @@ -165,7 +169,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -174,7 +178,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -247,7 +251,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -256,7 +260,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -265,7 +269,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -273,8 +277,8 @@ template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}(loc, result_types, - args, b); + return MapLhloOpToScalarOpImpl{}(loc, result_types, args, + b); } template <> @@ -282,8 +286,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, - b); + return MapLhloOpToScalarOpImpl{}(loc, result_types, args, b); } template <> @@ -291,8 +294,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, - b); + return MapLhloOpToScalarOpImpl{}(loc, result_types, args, b); } template <> @@ -373,15 +375,15 @@ inline Value MapLhloOpToStdScalarOp(Location loc, const auto& result = args[2]; Type element_type = lhs.getType(); if (element_type.isa()) { - Value float_mul = MapLhloOpToStdScalarOpImpl{}( + Value float_mul = MapLhloOpToScalarOpImpl{}( loc, result_types, {lhs, rhs}, b); - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, {float_mul, result}, b); } if (element_type.isa()) { - Value int_mul = MapLhloOpToStdScalarOpImpl{}( + Value int_mul = MapLhloOpToScalarOpImpl{}( loc, result_types, {lhs, rhs}, b); - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, {int_mul, result}, b); } return nullptr; @@ -392,7 +394,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -401,7 +403,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -410,7 +412,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -468,7 +470,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -507,7 +509,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -563,7 +565,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, OpBuilder* b) { Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } if (element_type.isa()) { @@ -604,8 +606,8 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( - loc, result_types, args, b); + return MapLhloOpToScalarOpImpl{}(loc, result_types, + args, b); } template <> @@ -613,7 +615,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -627,8 +629,8 @@ inline Value MapLhloOpToStdScalarOp(Location loc, // Floating point can use std::powf auto result_type = result_types.front(); if (result_type.isa<::mlir::FloatType>()) - return MapLhloOpToStdScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types, - args, b); + return MapLhloOpToScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types, + args, b); assert(result_type.isa<::mlir::IntegerType>() && "only float and integer `pow` is supported right now"); @@ -699,15 +701,15 @@ template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, - b); + return MapLhloOpToScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, + b); } template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -715,7 +717,7 @@ template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -723,7 +725,7 @@ template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -780,7 +782,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -789,9 +791,9 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl, - FloatType, ScalarFOp, - ComplexType, ScalarCOp>{}( + return MapLhloOpToScalarOpImpl, + FloatType, ScalarFOp, + ComplexType, ScalarCOp>{}( loc, result_types, args, b); } @@ -800,7 +802,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } @@ -809,7 +811,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToScalarOpImpl{}( loc, result_types, args, b); } diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index af3d85a..1d6f5f4 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -62,6 +62,30 @@ LmhloDialect::LmhloDialect(MLIRContext* context) >(); } +//===----------------------------------------------------------------------===// +// AbsOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(AbsOp op) { + auto operand_type = getElementTypeOrSelf(op.input().getType()); + auto output_type = getElementTypeOrSelf(op.output().getType()); + if (auto complex_type = operand_type.dyn_cast()) { + if (complex_type.getElementType() != output_type) { + return op.emitOpError( + "requires output type to be the same as the element type of the " + "input"); + } + return success(); + } + if (operand_type != output_type) + return op.emitOpError("requires all operands to have the same type"); + return success(); +} + +//===----------------------------------------------------------------------===// +// AllToAllOp +//===----------------------------------------------------------------------===// + // Verifies replica groups attached to collective communication operations. // If the attribute is not empty, it must be a rank 2 tensor, and each replica // should appear exactly once. If `is_uniform_sized` is true, then we also check @@ -120,8 +144,8 @@ static LogicalResult Verify(AllReduceOp op) { if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false))) return failure(); - // AllReduce had variadic operands and results that have the same size. - // Each memeber of the operand should have the same type as the corresponding + // AllReduce has variadic operands and results that have the same size. + // Each member of the operand should have the same type as the corresponding // member of the result. for (auto it : llvm::enumerate( llvm::zip(op.operands().getTypes(), op.results().getTypes()))) { diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 249e60c..1e0f3f4 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -378,6 +378,20 @@ func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @complex_abs +func @complex_abs(%input: memref<2x2xcomplex>, %result: memref<2x2xf32>) { + "lmhlo.abs"(%input, %result) + : (memref<2x2xcomplex>, memref<2x2xf32>) -> () + return +} + +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex, %[[ABS_OUT:.*]]: f32): +// CHECK-NEXT: %[[ABS:.*]] = complex.abs %[[CPLX_IN:.*]] : complex +// CHECK-NEXT: linalg.yield %[[ABS]] : f32 + +// ----- + // CHECK-LABEL: func @absi func @absi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir index 6c5bd6a..97eec5d 100644 --- a/tests/lhlo_ops.mlir +++ b/tests/lhlo_ops.mlir @@ -1151,3 +1151,20 @@ func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } + +// ----- + +func @invalid_complex_abs_call(%input:memref<2xcomplex>, %result:memref<2xcomplex>) -> () { + // expected-error @+1 {{requires output type to be the same as the element type of the input}} + "lmhlo.abs"(%input, %result) + : (memref<2xcomplex>, memref<2xcomplex>) -> () + return +} + +// ----- + +func @invalid_float_abs_call(%input:memref<2xf32>, %result:memref<2xf64>) -> () { + // expected-error @+1 {{requires all operands to have the same type}} + "lmhlo.abs"(%input, %result) : (memref<2xf32>, memref<2xf64>) -> () + return +}