[mlir][hlo] Fix lowering of NE comparison. It should return true if either side is NaN
PiperOrigin-RevId: 346988987
This commit is contained in:
parent
ab6ee11813
commit
9930c20c31
|
@ -172,7 +172,7 @@ inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
|
||||||
StringRef comparison_direction) {
|
StringRef comparison_direction) {
|
||||||
return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
|
return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
|
||||||
.Case("EQ", CmpFPredicate::OEQ)
|
.Case("EQ", CmpFPredicate::OEQ)
|
||||||
.Case("NE", CmpFPredicate::ONE)
|
.Case("NE", CmpFPredicate::UNE)
|
||||||
.Case("GE", CmpFPredicate::OGE)
|
.Case("GE", CmpFPredicate::OGE)
|
||||||
.Case("GT", CmpFPredicate::OGT)
|
.Case("GT", CmpFPredicate::OGT)
|
||||||
.Case("LE", CmpFPredicate::OLE)
|
.Case("LE", CmpFPredicate::OLE)
|
||||||
|
|
|
@ -232,6 +232,20 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @float_cmp_ne
|
||||||
|
func @float_cmp_ne(%lhs: tensor<2x2xf32>,
|
||||||
|
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
|
||||||
|
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"}
|
||||||
|
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||||
|
return %0 : tensor<2x2xi1>
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = cmpf "une", %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @int_cmp
|
// CHECK-LABEL: func @int_cmp
|
||||||
func @int_cmp(%lhs: tensor<2x2xi32>,
|
func @int_cmp(%lhs: tensor<2x2xi32>,
|
||||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
|
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
|
||||||
|
|
Loading…
Reference in New Issue