From a847109ac74e39f7ca80e6fe619ad1aaf20d86fb Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 25 May 2021 04:34:24 -0700 Subject: [PATCH] Support complex types when converting HLO multiply op. We can lower it to the MulOp in the complex dialect. PiperOrigin-RevId: 375675079 --- .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 14 ++++++++++++++ tests/hlo-legalize-to-linalg.mlir | 13 +++++++++++++ tests/lhlo-legalize-to-linalg.mlir | 14 ++++++++++++++ 3 files changed, 41 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index a4169eb..0a3c240 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -65,6 +65,7 @@ struct LhloToScalarOp { using FOp = ::mlir::MulFOp; using IOp = ::mlir::MulIOp; using UOp = ::mlir::MulIOp; + using COp = ::mlir::complex::MulOp; }; template <> struct LhloToScalarOp { @@ -631,6 +632,19 @@ inline Value MapLhloOpToStdScalarOp(Location loc, args, loc, b); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef arg_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToScalarOpImpl, + isUnsignedIntegerType, ScalarUOp, + isFloatType, ScalarFOp, + isComplexType, ScalarCOp>{}( + loc, result_types, arg_types, args, b); +} + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 79543e6..08c158e 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -65,6 +65,19 @@ func @integer_mul(%lhs: tensor<2x2xi32>, // ----- +// CHECK-LABEL: func @complex_mul +func @complex_mul(%lhs: tensor<2x2xcomplex>, + %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { + // CHECK: linalg.generic + // CHECK: complex.mul + %0 = "mhlo.multiply"(%lhs, %rhs) + : (tensor<2x2xcomplex>, tensor<2x2xcomplex>) + -> tensor<2x2xcomplex> + return %0 : tensor<2x2xcomplex> +} + +// ----- + // CHECK-LABEL: func @float_remainder func @float_remainder(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index dc96182..d980782 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -242,6 +242,20 @@ func @complex_divide(%lhs: memref<2xcomplex>, %rhs: memref<2xcomplex>, // ----- +// CHECK-LABEL: func @complex_multiply +func @complex_multiply(%lhs: memref<2xcomplex>, %rhs: memref<2xcomplex>, + %result: memref<2xcomplex>) { + "lmhlo.multiply"(%lhs, %rhs, %result) + : (memref<2xcomplex>, memref<2xcomplex>, memref<2xcomplex>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex, %[[RHS_IN:.*]]: complex, %[[RESULT_OUT:.*]]: complex): +// CHECK-NEXT: %[[RESULT:.*]] = complex.mul %[[LHS_IN]], %[[RHS_IN]] : complex +// CHECK-NEXT: linalg.yield %[[RESULT]] : complex + +// ----- + // CHECK-LABEL: func @select func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {