diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 7aa956f..c9a12ab 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -323,6 +323,10 @@ inline Value MapLhloOpToStdScalarOp( // No conversion is needed for the same width integers return args.front(); } + if (targetType.isInteger(/*width=*/1)) { + auto zero = b->create(loc, b->getFloatAttr(sourceType, 0.0)); + return b->create(loc, CmpFPredicate::UNE, args.front(), zero); + } if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) { return b->create(loc, result_types, args, mlir::None); } diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 59e7255..dda3ff5 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -727,6 +727,20 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @convert_f32_to_i1 +func @convert_f32_to_i1(%input: tensor<2x2xf32>) -> tensor<2x2xi1> { + %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1> + return %result : tensor<2x2xi1> +} +// CHECK: linalg.init_tensor +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %{{.*}}: i1): +// CHECK-NEXT: %[[ZERO:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[RESULT:.*]] = cmpf une, %[[OPERAND_IN]], %[[ZERO]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @convert_f32_to_i32 func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index e0ab3e2..39a6d72 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -502,6 +502,20 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @convert_f32_to_i1 +func @convert_f32_to_i1(%input: memref<2x2xf32>, %result: memref<2x2xi1>) { + "lmhlo.convert"(%input, %result) + : (memref<2x2xf32>, memref<2x2xi1>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i1): +// CHECK-NEXT: %[[ZERO:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[RESULT:.*]] = cmpf une, %[[OPERAND_IN]], %[[ZERO]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @convert_f32_to_i32 func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { "lmhlo.convert"(%input, %result)