diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index c153031..d0cbe59 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -887,6 +887,8 @@ inline Value MapLhloOpToStdScalarOp(Location loc, b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); + } else if (element_type.isa()) { + return b->create<::mlir::complex::SignOp>(loc, element_type, args.front()); } return nullptr; } diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index e3f0f37..0a08562 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -256,6 +256,18 @@ func @complex_neg(%arg0: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { // ----- +// CHECK-LABEL: func @complex_sign +func @complex_sign( + %arg0: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { + // CHECK: linalg.generic + // CHECK: complex.sign + %0 = "mhlo.sign"(%arg0) : (tensor<2x2xcomplex>) + -> tensor<2x2xcomplex> + return %0 : tensor<2x2xcomplex> +} + +// ----- + // CHECK-LABEL: func @float_tanh func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 3fd89b4..f8d7dea 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -810,6 +810,20 @@ func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) { // ----- +// CHECK-LABEL: func @sign_complex +func @sign_complex(%input: memref<2x2xcomplex>, + %result: memref<2x2xcomplex>) { + "lmhlo.sign"(%input, %result) : (memref<2x2xcomplex>, + memref<2x2xcomplex>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: complex, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = complex.sign %[[OPERAND_IN]] : complex +// CHECK-NEXT: linalg.yield %[[RESULT]] : complex + +// ----- + // CHECK-LABEL: func @sqrt func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { "lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()