Fix constant folding of mhlo.convert op with i1 element types

Boolean element values should be fetched as an unsigned integer and not signed integer which would return -1 for true.

Added to a TODO to handle unsigned types correctly as well as we don't seem to be using unsigned types.

PiperOrigin-RevId: 343927564
This commit is contained in:
Smit Hinsu 2020-11-23 14:17:47 -08:00 committed by TensorFlow MLIR Team
parent 2ac41d8cd2
commit b016b5a219
2 changed files with 34 additions and 7 deletions

View File

@ -61,12 +61,16 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
// mapValues always takes a function returning APInt, even when the output // mapValues always takes a function returning APInt, even when the output
// is actually float. // is actually float.
using func_type = llvm::APInt(const llvm::APInt&); using func_type = llvm::APInt(const llvm::APInt&);
// TODO(hinsu): Correctly handle unsigned element types.
bool is_bool = old_type.isInteger(1);
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) { if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
// Int -> Float // Int -> Float
return elements.mapValues( return elements.mapValues(
new_type, llvm::function_ref<func_type>([&newFloatType]( new_type, llvm::function_ref<func_type>([&newFloatType, &is_bool](
const llvm::APInt& intVal) { const llvm::APInt& intVal) {
llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue())); int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
llvm::APFloat newDouble(static_cast<double>(val));
bool loses_info = false; bool loses_info = false;
newDouble.convert(newFloatType.getFloatSemantics(), newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven, &loses_info); llvm::APFloat::rmNearestTiesToEven, &loses_info);
@ -76,9 +80,10 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
// new_type is Integer // new_type is Integer
// Int -> Int // Int -> Int
return elements.mapValues( return elements.mapValues(
new_type, new_type, llvm::function_ref<func_type>([&bit_width, &is_bool](
llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) { const llvm::APInt& intVal) {
return llvm::APInt(bit_width, intVal.getSExtValue()); int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
return llvm::APInt(bit_width, val);
})); }));
} }

View File

@ -123,6 +123,17 @@ func @const_int_bf16() -> tensor<bf16> {
// ----- // -----
// CHECK-LABEL: func @const_bool_f32
func @const_bool_f32() -> tensor<2xf32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xf32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2xf32>
}
// -----
// CHECK-LABEL: func @const_bf16_int // CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i16> { func @const_bf16_int() -> tensor<i16> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16> // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
@ -145,8 +156,8 @@ func @const_int_narrowing() -> tensor<i32> {
// ----- // -----
// CHECK-LABEL: func @const_int_widening // CHECK-LABEL: func @const_bool_widening
func @const_int_widening() -> tensor<i64> { func @const_bool_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64> // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
%cst = mhlo.constant dense<42> : tensor<i32> %cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64> %0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
@ -156,6 +167,17 @@ func @const_int_widening() -> tensor<i64> {
// ----- // -----
// CHECK-LABEL: func @const_int_widening
func @const_int_widening() -> tensor<2xi32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0, 1]> : tensor<2xi32>
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2xi32>
}
// -----
// CHECK-LABEL: func @const_negative_int_widening // CHECK-LABEL: func @const_negative_int_widening
func @const_negative_int_widening() -> tensor<i64> { func @const_negative_int_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64> // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>