[mlir][hlo] Fix lowering of NE comparison. It should return true if either side is NaN

PiperOrigin-RevId: 346988987
This commit is contained in:
Benjamin Kramer 2020-12-11 06:45:22 -08:00 committed by TensorFlow MLIR Team
parent ab6ee11813
commit 9930c20c31
2 changed files with 15 additions and 1 deletions

View File

@ -172,7 +172,7 @@ inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
StringRef comparison_direction) {
return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
.Case("EQ", CmpFPredicate::OEQ)
.Case("NE", CmpFPredicate::ONE)
.Case("NE", CmpFPredicate::UNE)
.Case("GE", CmpFPredicate::OGE)
.Case("GT", CmpFPredicate::OGT)
.Case("LE", CmpFPredicate::OLE)

View File

@ -232,6 +232,20 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
// -----
// CHECK-LABEL: func @float_cmp_ne
func @float_cmp_ne(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = cmpf "une", %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {