diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 636cd8c..7ce33fb 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -288,7 +288,12 @@ inline Value MapLhloOpToStdScalarOp( Type sourceType = getElementTypeOrSelf(args.front().getType()); Type targetType = getElementTypeOrSelf(result_types.front()); - if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { + // A boolean value is considered to be unsigned when converting to + // floating-point. Otherwise, it will become `-1`. + if (sourceType.isInteger(/*width=*/1) && + mlir::UIToFPOp::areCastCompatible(sourceType, targetType)) { + return b->create(loc, result_types, args, mlir::None); + } else if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { return b->create(loc, result_types, args, mlir::None); } else if (sourceType.isa() && targetType.isa()) { FloatType src = sourceType.cast(); diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 159a616..4b2f354 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -608,6 +608,19 @@ func @reshape_multiple_collapse // ----- +// CHECK-LABEL: func @convert_i1_to_f32 +func @convert_i1_to_f32(%input: tensor<2x2xi1>) -> tensor<2x2xf32> { + %result = "mhlo.convert"(%input) : (tensor<2x2xi1>) -> tensor<2x2xf32> + return %result : tensor<2x2xf32> +} +// CHECK: linalg.init_tensor +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i1, %{{.*}}: f32): +// CHECK-NEXT: %[[RESULT:.*]] = uitofp %[[OPERAND_IN]] : i1 to f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + // CHECK-LABEL: func @convert_i32_to_f32 func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 1180847..3c75795 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -404,6 +404,18 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @convert_i1_to_f32 +func @convert_i1_to_f32(%input: memref<2x2xi1>, %result: memref<2x2xf32>) { + "lmhlo.convert"(%input, %result) : (memref<2x2xi1>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i1, %[[RESULT_OUT:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = uitofp %[[OPERAND_IN]] : i1 to f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + // CHECK-LABEL: func @convert_i32_to_f32 func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()