[mlir][hlo] Fix lowering of NE comparison. It should return true if either side is NaN
PiperOrigin-RevId: 346988987
This commit is contained in:
parent
ab6ee11813
commit
9930c20c31
|
@ -172,7 +172,7 @@ inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
|
|||
StringRef comparison_direction) {
|
||||
return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
|
||||
.Case("EQ", CmpFPredicate::OEQ)
|
||||
.Case("NE", CmpFPredicate::ONE)
|
||||
.Case("NE", CmpFPredicate::UNE)
|
||||
.Case("GE", CmpFPredicate::OGE)
|
||||
.Case("GT", CmpFPredicate::OGT)
|
||||
.Case("LE", CmpFPredicate::OLE)
|
||||
|
|
|
@ -232,6 +232,20 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_cmp_ne
|
||||
func @float_cmp_ne(%lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
|
||||
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"}
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
return %0 : tensor<2x2xi1>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = cmpf "une", %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @int_cmp
|
||||
func @int_cmp(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
|
||||
|
|
Loading…
Reference in New Issue