From 581692025816ddfeeb0acd2a5a8405e1ed1e6a28 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 25 May 2021 03:42:25 -0700 Subject: [PATCH] Support complex types when converting HLO divide op. We can lower it to the DivOp in the complex dialect. Also add tests to hlo-legalize-to-linalg.mlir for CompareOp lowering of complex types. These were forgotten in a previous commit. PiperOrigin-RevId: 375669125 --- .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 15 +++++++ tests/hlo-legalize-to-linalg.mlir | 41 +++++++++++++++++++ tests/lhlo-legalize-to-linalg.mlir | 18 +++++++- 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 7b18b4f..a4169eb 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -58,6 +58,7 @@ struct LhloToScalarOp { using FOp = ::mlir::DivFOp; using IOp = ::mlir::SignedDivIOp; using UOp = ::mlir::UnsignedDivIOp; + using COp = ::mlir::complex::DivOp; }; template <> struct LhloToScalarOp { @@ -192,6 +193,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, } return nullptr; } + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, @@ -301,6 +303,19 @@ inline Value MapLhloOpToStdScalarOp(Location loc, return args.front(); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef arg_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToScalarOpImpl, + isUnsignedIntegerType, ScalarUOp, + isFloatType, ScalarFOp, + isComplexType, ScalarCOp>{}( + loc, result_types, arg_types, args, b); +} + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index f46cc97..79543e6 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -323,6 +323,36 @@ func @int_cmp(%lhs: tensor<2x2xi32>, // ----- +// CHECK-LABEL: func @complex_cmp_eq +func @complex_cmp_eq(%lhs: tensor<2xcomplex>, + %rhs: tensor<2xcomplex>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} + : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xi1>) + return %0 : tensor<2xi1> +} +// CHECK: linalg.init_tensor [2] : tensor<2xi1> +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex, %[[RHS_IN:.*]]: complex, %[[RESULT_OUT:.*]]: i1): +// CHECK-NEXT: %[[RESULT:.*]] = complex.eq %[[LHS_IN]], %[[RHS_IN]] : complex +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + +// CHECK-LABEL: func @complex_cmp_neq +func @complex_cmp_neq(%lhs: tensor<2xcomplex>, + %rhs: tensor<2xcomplex>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"} + : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xi1>) + return %0 : tensor<2xi1> +} +// CHECK: linalg.init_tensor [2] : tensor<2xi1> +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex, %[[RHS_IN:.*]]: complex, %[[RESULT_OUT:.*]]: i1): +// CHECK-NEXT: %[[RESULT:.*]] = complex.neq %[[LHS_IN]], %[[RHS_IN]] : complex +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @float_cos func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic @@ -2352,6 +2382,17 @@ func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor< // ----- +// CHECK-LABEL: complex_divide +func @complex_divide(%lhs: tensor<2xcomplex>, + %rhs: tensor<2xcomplex>) -> tensor<2xcomplex> { + // CHECK: linalg.generic + // CHECK: complex.div + %0 = "mhlo.divide"(%lhs, %rhs) : (tensor<2xcomplex>, tensor<2xcomplex>) -> tensor<2xcomplex> + return %0 : tensor<2xcomplex> +} + +// ----- + // CHECK-LABEL: unsigned_remainder func @unsigned_remainder(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> { // CHECK: linalg.generic diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index ab3dfea..dc96182 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -158,8 +158,8 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { // ----- -// CHECK-LABEL: func @is_finte -func @is_finte(%input: memref<2x2xf32>, %result: memref<2x2xi1>) { +// CHECK-LABEL: func @is_finite +func @is_finite(%input: memref<2x2xf32>, %result: memref<2x2xi1>) { "lmhlo.is_finite"(%input, %result) : (memref<2x2xf32>, memref<2x2xi1>) -> () return } @@ -228,6 +228,20 @@ func @complex_cmp_neq(%lhs: memref<2xcomplex>, %rhs: memref<2xcomplex> // ----- +// CHECK-LABEL: func @complex_divide +func @complex_divide(%lhs: memref<2xcomplex>, %rhs: memref<2xcomplex>, + %result: memref<2xcomplex>) { + "lmhlo.divide"(%lhs, %rhs, %result) + : (memref<2xcomplex>, memref<2xcomplex>, memref<2xcomplex>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex, %[[RHS_IN:.*]]: complex, %[[RESULT_OUT:.*]]: complex): +// CHECK-NEXT: %[[RESULT:.*]] = complex.div %[[LHS_IN]], %[[RHS_IN]] : complex +// CHECK-NEXT: linalg.yield %[[RESULT]] : complex + +// ----- + // CHECK-LABEL: func @select func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {