Fix constant folding of mhlo.convert op with i1 element types
Boolean element values should be fetched as an unsigned integer and not signed integer which would return -1 for true. Added to a TODO to handle unsigned types correctly as well as we don't seem to be using unsigned types. PiperOrigin-RevId: 343927564
This commit is contained in:
parent
2ac41d8cd2
commit
b016b5a219
|
@ -61,12 +61,16 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
|||
// mapValues always takes a function returning APInt, even when the output
|
||||
// is actually float.
|
||||
using func_type = llvm::APInt(const llvm::APInt&);
|
||||
|
||||
// TODO(hinsu): Correctly handle unsigned element types.
|
||||
bool is_bool = old_type.isInteger(1);
|
||||
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
|
||||
// Int -> Float
|
||||
return elements.mapValues(
|
||||
new_type, llvm::function_ref<func_type>([&newFloatType](
|
||||
new_type, llvm::function_ref<func_type>([&newFloatType, &is_bool](
|
||||
const llvm::APInt& intVal) {
|
||||
llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
|
||||
int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
|
||||
llvm::APFloat newDouble(static_cast<double>(val));
|
||||
bool loses_info = false;
|
||||
newDouble.convert(newFloatType.getFloatSemantics(),
|
||||
llvm::APFloat::rmNearestTiesToEven, &loses_info);
|
||||
|
@ -76,9 +80,10 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
|||
// new_type is Integer
|
||||
// Int -> Int
|
||||
return elements.mapValues(
|
||||
new_type,
|
||||
llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
|
||||
return llvm::APInt(bit_width, intVal.getSExtValue());
|
||||
new_type, llvm::function_ref<func_type>([&bit_width, &is_bool](
|
||||
const llvm::APInt& intVal) {
|
||||
int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
|
||||
return llvm::APInt(bit_width, val);
|
||||
}));
|
||||
}
|
||||
|
||||
|
|
|
@ -123,6 +123,17 @@ func @const_int_bf16() -> tensor<bf16> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @const_bool_f32
|
||||
func @const_bool_f32() -> tensor<2xf32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>
|
||||
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xf32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @const_bf16_int
|
||||
func @const_bf16_int() -> tensor<i16> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
|
||||
|
@ -145,8 +156,8 @@ func @const_int_narrowing() -> tensor<i32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @const_int_widening
|
||||
func @const_int_widening() -> tensor<i64> {
|
||||
// CHECK-LABEL: func @const_bool_widening
|
||||
func @const_bool_widening() -> tensor<i64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
|
||||
%cst = mhlo.constant dense<42> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
||||
|
@ -156,6 +167,17 @@ func @const_int_widening() -> tensor<i64> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @const_int_widening
|
||||
func @const_int_widening() -> tensor<2xi32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0, 1]> : tensor<2xi32>
|
||||
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xi32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @const_negative_int_widening
|
||||
func @const_negative_int_widening() -> tensor<i64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
|
||||
|
|
Loading…
Reference in New Issue