diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 288e04c..c153031 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -671,8 +671,9 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef args, OpBuilder* b) { Type element_type = getElementTypeOrSelf(args.front().getType()); - if (element_type.isa()) { - return MapLhloOpToScalarOpImpl{}( + if (element_type.isa()) { + return MapLhloOpToScalarOpImpl{}( loc, result_types, arg_types, args, b); } if (element_type.isa()) { diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 409a648..e3f0f37 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -245,6 +245,17 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @complex_neg +func @complex_neg(%arg0: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { + // CHECK: linalg.generic + // CHECK: complex.neg + %0 = "mhlo.negate"(%arg0) : (tensor<2x2xcomplex>) + -> tensor<2x2xcomplex> + return %0 : tensor<2x2xcomplex> +} + +// ----- + // CHECK-LABEL: func @float_tanh func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index a7b0e1d..3fd89b4 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -690,6 +690,22 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @complex_neg +func @complex_neg(%input: memref<2x2xcomplex>, + %result: memref<2x2xcomplex>) { + "lmhlo.negate"(%input, %result) : (memref<2x2xcomplex>, + memref<2x2xcomplex>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: complex, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = complex.neg %[[OPERAND_IN]] : complex +// CHECK-NEXT: linalg.yield %[[RESULT]] : complex + +// ----- + +// ----- + // CHECK-LABEL: func @negi func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { "lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()