Fix constant folding of mhlo.convert op with i1 element types
Boolean element values should be fetched as an unsigned integer and not signed integer which would return -1 for true. Added to a TODO to handle unsigned types correctly as well as we don't seem to be using unsigned types. PiperOrigin-RevId: 343927564
This commit is contained in:
parent
2ac41d8cd2
commit
b016b5a219
|
@ -61,12 +61,16 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
||||||
// mapValues always takes a function returning APInt, even when the output
|
// mapValues always takes a function returning APInt, even when the output
|
||||||
// is actually float.
|
// is actually float.
|
||||||
using func_type = llvm::APInt(const llvm::APInt&);
|
using func_type = llvm::APInt(const llvm::APInt&);
|
||||||
|
|
||||||
|
// TODO(hinsu): Correctly handle unsigned element types.
|
||||||
|
bool is_bool = old_type.isInteger(1);
|
||||||
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
|
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
|
||||||
// Int -> Float
|
// Int -> Float
|
||||||
return elements.mapValues(
|
return elements.mapValues(
|
||||||
new_type, llvm::function_ref<func_type>([&newFloatType](
|
new_type, llvm::function_ref<func_type>([&newFloatType, &is_bool](
|
||||||
const llvm::APInt& intVal) {
|
const llvm::APInt& intVal) {
|
||||||
llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
|
int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
|
||||||
|
llvm::APFloat newDouble(static_cast<double>(val));
|
||||||
bool loses_info = false;
|
bool loses_info = false;
|
||||||
newDouble.convert(newFloatType.getFloatSemantics(),
|
newDouble.convert(newFloatType.getFloatSemantics(),
|
||||||
llvm::APFloat::rmNearestTiesToEven, &loses_info);
|
llvm::APFloat::rmNearestTiesToEven, &loses_info);
|
||||||
|
@ -76,9 +80,10 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
||||||
// new_type is Integer
|
// new_type is Integer
|
||||||
// Int -> Int
|
// Int -> Int
|
||||||
return elements.mapValues(
|
return elements.mapValues(
|
||||||
new_type,
|
new_type, llvm::function_ref<func_type>([&bit_width, &is_bool](
|
||||||
llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
|
const llvm::APInt& intVal) {
|
||||||
return llvm::APInt(bit_width, intVal.getSExtValue());
|
int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
|
||||||
|
return llvm::APInt(bit_width, val);
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -123,6 +123,17 @@ func @const_int_bf16() -> tensor<bf16> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @const_bool_f32
|
||||||
|
func @const_bool_f32() -> tensor<2xf32> {
|
||||||
|
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>
|
||||||
|
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
|
||||||
|
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xf32>
|
||||||
|
// CHECK-NEXT: return [[CST]]
|
||||||
|
return %0 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @const_bf16_int
|
// CHECK-LABEL: func @const_bf16_int
|
||||||
func @const_bf16_int() -> tensor<i16> {
|
func @const_bf16_int() -> tensor<i16> {
|
||||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
|
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
|
||||||
|
@ -145,8 +156,8 @@ func @const_int_narrowing() -> tensor<i32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @const_int_widening
|
// CHECK-LABEL: func @const_bool_widening
|
||||||
func @const_int_widening() -> tensor<i64> {
|
func @const_bool_widening() -> tensor<i64> {
|
||||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
|
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
|
||||||
%cst = mhlo.constant dense<42> : tensor<i32>
|
%cst = mhlo.constant dense<42> : tensor<i32>
|
||||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
||||||
|
@ -156,6 +167,17 @@ func @const_int_widening() -> tensor<i64> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @const_int_widening
|
||||||
|
func @const_int_widening() -> tensor<2xi32> {
|
||||||
|
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0, 1]> : tensor<2xi32>
|
||||||
|
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
|
||||||
|
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xi32>
|
||||||
|
// CHECK-NEXT: return [[CST]]
|
||||||
|
return %0 : tensor<2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @const_negative_int_widening
|
// CHECK-LABEL: func @const_negative_int_widening
|
||||||
func @const_negative_int_widening() -> tensor<i64> {
|
func @const_negative_int_widening() -> tensor<i64> {
|
||||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
|
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
|
||||||
|
|
Loading…
Reference in New Issue