diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 91ed8f0..7b18b4f 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -279,6 +279,16 @@ inline Value MapCompareOpToStdScalarOp(Location loc, return b->create>(loc, predicate.getValue(), lhs, rhs); } + if (auto complex_type = element_type.dyn_cast()) { + if (complex_type.getElementType().isa()) { + if (comparison_direction == "EQ") { + return b->create(loc, lhs, rhs); + } + if (comparison_direction == "NE") { + return b->create(loc, lhs, rhs); + } + } + } return nullptr; } diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 079838f..ab3dfea 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -188,7 +188,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @int_cmp func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, - %result: memref<2x2xi1>) { + %result: memref<2x2xi1>) { "lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> () return @@ -200,6 +200,34 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // ----- +// CHECK-LABEL: func @complex_cmp_eq +func @complex_cmp_eq(%lhs: memref<2xcomplex>, %rhs: memref<2xcomplex>, + %result: memref<2xi1>) { + "lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"} + : (memref<2xcomplex>, memref<2xcomplex>, memref<2xi1>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex, %[[RHS_IN:.*]]: complex, %[[RESULT_OUT:.*]]: i1): +// CHECK-NEXT: %[[RESULT:.*]] = complex.eq %[[LHS_IN]], %[[RHS_IN]] : complex +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + +// CHECK-LABEL: func @complex_cmp_neq +func @complex_cmp_neq(%lhs: memref<2xcomplex>, %rhs: memref<2xcomplex>, + %result: memref<2xi1>) { + "lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "NE"} + : (memref<2xcomplex>, memref<2xcomplex>, memref<2xi1>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex, %[[RHS_IN:.*]]: complex, %[[RESULT_OUT:.*]]: i1): +// CHECK-NEXT: %[[RESULT:.*]] = complex.neq %[[LHS_IN]], %[[RHS_IN]] : complex +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @select func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {