From b1438eebcb4fd688f06e0803ff039776056c8d1e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 25 Jan 2021 05:42:24 -0800 Subject: [PATCH] [mlir][hlo] Make min/max always propagate NaNs This is the right behavior for TF and JAX and matches what TF does on GPU. It doesn't match TF on CPU, but that's really a TF bug. PiperOrigin-RevId: 353628258 --- .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 36 +++++-------------- tests/hlo-legalize-to-linalg.mlir | 17 +++------ tests/lhlo-legalize-to-affine.mlir | 19 +++------- tests/lhlo-legalize-to-linalg.mlir | 11 +++--- 4 files changed, 22 insertions(+), 61 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 3820423..9354830 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -437,20 +437,6 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } -inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef args, Location loc, - OpBuilder* b) { - Type element_type = getElementTypeOrSelf(args.front().getType()); - if (auto float_type = element_type.dyn_cast()) { - Value isnan = - b->create(loc, CmpFPredicate::UNO, args[0], args[1]); - - auto nan_apfloat = APFloat::getQNaN(float_type.getFloatSemantics()); - Value nan = b->create(loc, nan_apfloat, float_type); - v = b->create(loc, isnan, nan, v); - } - return v; -} - template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, @@ -478,13 +464,10 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return LhloAlwaysPropagateNaN( - CompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>::map(loc, "GT", - result_types, args, - b), - args, loc, b); + return CompareSelectOpToStdScalarOp< + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "GT", result_types, + args, b); } template <> @@ -492,13 +475,10 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return LhloAlwaysPropagateNaN( - CompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>::map(loc, "LT", - result_types, args, - b), - args, loc, b); + return CompareSelectOpToStdScalarOp< + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "LT", result_types, + args, b); } template <> diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 48332d2..f4e9287 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -540,10 +540,7 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32 -// CHECK-NEXT: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 -// CHECK-NEXT: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32 -// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- @@ -953,14 +950,10 @@ func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>) // CHECK: %[[INIT:.*]] = linalg.init_tensor // CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>) // CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32): - // CHECK: cmpf olt - // CHECK: select - // CHECK: cmpf uno - // CHECK: select - // CHECK: cmpf ogt - // CHECK: select - // CHECK: cmpf uno - // CHECK: %[[MAX_X2_LB:.*]] = select + // CHECK: %[[LT_X_UB:.*]] = cmpf olt, %[[SCALAR_X]], %[[SCALAR_UB]] + // CHECK: %[[X2:.*]] = select %[[LT_X_UB]], %[[SCALAR_X]], %[[SCALAR_UB]] + // CHECK: %[[GT_X2_LB:.*]] = cmpf ogt, %[[X2]], %[[SCALAR_LB]] + // CHECK: %[[MAX_X2_LB:.*]] = select %[[GT_X2_LB]], %[[X2]], %[[SCALAR_LB]] // CHECK: linalg.yield %[[MAX_X2_LB]] // CHECK: } -> tensor<4xf32> // CHECK: return %[[RESULT]] : tensor<4xf32> diff --git a/tests/lhlo-legalize-to-affine.mlir b/tests/lhlo-legalize-to-affine.mlir index 35c0bef..ee67ea1 100644 --- a/tests/lhlo-legalize-to-affine.mlir +++ b/tests/lhlo-legalize-to-affine.mlir @@ -4,7 +4,6 @@ // CHECK-LABEL: func @min_op func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, %result: memref<4x3x2x1xf32>) -> () { - // CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32 // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 4 { // CHECK-NEXT: affine.for %[[J:.*]] = 0 to 3 { // CHECK-NEXT: affine.for %[[K:.*]] = 0 to 2 { @@ -13,9 +12,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, // CHECK-NEXT: %[[RHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> // CHECK-NEXT: %[[MIN_PREDICATE:.*]] = cmpf olt, %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32 - // CHECK-NEXT: %[[ISNAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32 - // CHECK-NEXT: %[[MIN_NONAN:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32 - // CHECK-NEXT: affine.store %[[MIN_NONAN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> + // CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> // CHECK: return "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> () @@ -72,11 +69,8 @@ func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, // CHECK-LABEL: func @float_max_op func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { - // CHECK: %[[NAN:.*]] = constant 0x7FC00000 : f32 - // CHECK: %[[CMP:.*]] = cmpf ogt, %[[LHS_IN:.*]], %[[RHS_IN:.*]] : f32 - // CHECK: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 - // CHECK: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32 - // CHECK: select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32 + // CHECK: %[[CHECK:.*]] = cmpf ogt, %[[ONE:.*]], %[[TWO:.*]] : f32 + // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return @@ -96,11 +90,8 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, // CHECK-LABEL: func @float_min_op func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { - // CHECK: %[[NAN:.*]] = constant 0x7FC00000 : f32 - // CHECK: %[[CMP:.*]] = cmpf olt, %[[LHS_IN:.*]], %[[RHS_IN:.*]] : f32 - // CHECK: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 - // CHECK: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32 - // CHECK: select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32 + // CHECK: %[[CHECK:.*]] = cmpf olt, %[[ONE:.*]], %[[TWO:.*]] : f32 + // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 36bbdb7..dfcc3c7 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -70,10 +70,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32): // CHECK-NEXT: %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32 -// CHECK-NEXT: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 -// CHECK-NEXT: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32 -// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- @@ -951,9 +948,9 @@ func @reduce_maximum(%arg: memref<100x10xf32>, // CHECK-NEXT: store // CHECK-NEXT: load // CHECK-NEXT: load -// CHECK: cmpf -// CHECK: select -// CHECK: store +// CHECK-NEXT: cmpf +// CHECK-NEXT: select +// CHECK-NEXT: store // CHECK-NEXT: load // CHECK-NEXT: linalg.yield // CHECK-NEXT: }