diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 92032e4..a311ae7 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -172,7 +172,7 @@ inline Optional getCmpPredicate( StringRef comparison_direction) { return llvm::StringSwitch>(comparison_direction) .Case("EQ", CmpFPredicate::OEQ) - .Case("NE", CmpFPredicate::ONE) + .Case("NE", CmpFPredicate::UNE) .Case("GE", CmpFPredicate::OGE) .Case("GT", CmpFPredicate::OGT) .Case("LE", CmpFPredicate::OLE) diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 62f416f..517dca8 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -232,6 +232,20 @@ func @float_cmp(%lhs: tensor<2x2xf32>, // ----- +// CHECK-LABEL: func @float_cmp_ne +func @float_cmp_ne(%lhs: tensor<2x2xf32>, + %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) { + %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"} + : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> + return %0 : tensor<2x2xi1> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = cmpf "une", %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @int_cmp func @int_cmp(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {