diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index bcef846..1dbbb3b 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -316,9 +316,13 @@ inline Value MapLhloOpToStdScalarOp( IntegerType res = targetType.cast(); if (src.getWidth() > res.getWidth()) { return b->create(loc, result_types, args, mlir::None); - } else if (src.getWidth() < res.getWidth()) { + } else if (src.getWidth() == 1) { + // Special case boolean values, so they get casted to `1` instead of `-1`. return b->create(loc, result_types, args, mlir::None); + } else if (src.getWidth() < res.getWidth()) { + return b->create(loc, result_types, args, + mlir::None); } // No conversion is needed for the same width integers return args.front(); diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index dda3ff5..2612e5d 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -662,6 +662,19 @@ func @convert_i1_to_f32(%input: tensor<2x2xi1>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @convert_i1_to_i32 +func @convert_i1_to_i32(%input: tensor<2x2xi1>) -> tensor<2x2xi32> { + %result = "mhlo.convert"(%input) : (tensor<2x2xi1>) -> tensor<2x2xi32> + return %result : tensor<2x2xi32> +} +// CHECK: linalg.init_tensor +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i1, %{{.*}}: i32): +// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i1 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + // CHECK-LABEL: func @convert_i32_to_f32 func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> @@ -683,7 +696,7 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> { // CHECK: linalg.init_tensor // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %{{.*}}: i32): -// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32 +// CHECK-NEXT: %[[RESULT:.*]] = sexti %[[OPERAND_IN]] : i16 to i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 // ----- diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 39a6d72..66a580b 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -419,6 +419,18 @@ func @convert_i1_to_f32(%input: memref<2x2xi1>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @convert_i1_to_i32 +func @convert_i1_to_i32(%input: memref<2x2xi1>, %result: memref<2x2xi32>) { + "lmhlo.convert"(%input, %result) : (memref<2x2xi1>, memref<2x2xi32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i1, %[[RESULT_OUT:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i1 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + // CHECK-LABEL: func @convert_i32_to_f32 func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> () @@ -439,7 +451,7 @@ func @convert_i16_to_i32(%input: memref<2x2xi16>, } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]: i32): -// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32 +// CHECK-NEXT: %[[RESULT:.*]] = sexti %[[OPERAND_IN]] : i16 to i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 // -----