[mlir][hlo] Make min/max always propagate NaNs
This is the right behavior for TF and JAX and matches what TF does on GPU. It doesn't match TF on CPU, but that's really a TF bug. PiperOrigin-RevId: 353628258
This commit is contained in:
		
							parent
							
								
									6af4bccfde
								
							
						
					
					
						commit
						b1438eebcb
					
				|  | @ -437,20 +437,6 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc, | |||
|       loc, result_types, args, b); | ||||
| } | ||||
| 
 | ||||
| inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc, | ||||
|                                     OpBuilder* b) { | ||||
|   Type element_type = getElementTypeOrSelf(args.front().getType()); | ||||
|   if (auto float_type = element_type.dyn_cast<FloatType>()) { | ||||
|     Value isnan = | ||||
|         b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[1]); | ||||
| 
 | ||||
|     auto nan_apfloat = APFloat::getQNaN(float_type.getFloatSemantics()); | ||||
|     Value nan = b->create<mlir::ConstantFloatOp>(loc, nan_apfloat, float_type); | ||||
|     v = b->create<mlir::SelectOp>(loc, isnan, nan, v); | ||||
|   } | ||||
|   return v; | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>( | ||||
|     Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, | ||||
|  | @ -478,13 +464,10 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc, | |||
|                                                   ArrayRef<Type> result_types, | ||||
|                                                   ArrayRef<Value> args, | ||||
|                                                   OpBuilder* b) { | ||||
|   return LhloAlwaysPropagateNaN( | ||||
|       CompareSelectOpToStdScalarOp< | ||||
|           IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, | ||||
|           ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", | ||||
|                                                            result_types, args, | ||||
|                                                            b), | ||||
|       args, loc, b); | ||||
|   return CompareSelectOpToStdScalarOp< | ||||
|       IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, | ||||
|       ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types, | ||||
|                                                        args, b); | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
|  | @ -492,13 +475,10 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc, | |||
|                                                   ArrayRef<Type> result_types, | ||||
|                                                   ArrayRef<Value> args, | ||||
|                                                   OpBuilder* b) { | ||||
|   return LhloAlwaysPropagateNaN( | ||||
|       CompareSelectOpToStdScalarOp< | ||||
|           IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, | ||||
|           ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", | ||||
|                                                            result_types, args, | ||||
|                                                            b), | ||||
|       args, loc, b); | ||||
|   return CompareSelectOpToStdScalarOp< | ||||
|       IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, | ||||
|       ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types, | ||||
|                                                        args, b); | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
|  |  | |||
|  | @ -540,10 +540,7 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { | |||
| // CHECK: linalg.generic | ||||
| // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32): | ||||
| // CHECK-NEXT:   %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
| // CHECK-NEXT:   %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
| // CHECK-NEXT:   %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
| // CHECK-NEXT:   %[[NAN:.*]] = constant 0x7FC00000 : f32 | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32 | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
| // CHECK-NEXT:   linalg.yield %[[RESULT]] : f32 | ||||
| 
 | ||||
| // ----- | ||||
|  | @ -953,14 +950,10 @@ func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>) | |||
|   // CHECK: %[[INIT:.*]] = linalg.init_tensor | ||||
|   // CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>) | ||||
|   // CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32): | ||||
|   // CHECK:   cmpf olt | ||||
|   // CHECK:   select | ||||
|   // CHECK:   cmpf uno | ||||
|   // CHECK:   select | ||||
|   // CHECK:   cmpf ogt | ||||
|   // CHECK:   select | ||||
|   // CHECK:   cmpf uno | ||||
|   // CHECK:   %[[MAX_X2_LB:.*]] = select | ||||
|   // CHECK:   %[[LT_X_UB:.*]] = cmpf olt, %[[SCALAR_X]], %[[SCALAR_UB]] | ||||
|   // CHECK:   %[[X2:.*]] = select %[[LT_X_UB]], %[[SCALAR_X]], %[[SCALAR_UB]] | ||||
|   // CHECK:   %[[GT_X2_LB:.*]] = cmpf ogt, %[[X2]], %[[SCALAR_LB]] | ||||
|   // CHECK:   %[[MAX_X2_LB:.*]] = select %[[GT_X2_LB]], %[[X2]], %[[SCALAR_LB]] | ||||
|   // CHECK:   linalg.yield %[[MAX_X2_LB]] | ||||
|   // CHECK: } -> tensor<4xf32> | ||||
|   // CHECK: return %[[RESULT]] : tensor<4xf32> | ||||
|  |  | |||
|  | @ -4,7 +4,6 @@ | |||
| // CHECK-LABEL: func @min_op | ||||
| func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, | ||||
|              %result: memref<4x3x2x1xf32>) -> () { | ||||
|   // CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32 | ||||
|   // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 4 { | ||||
|   // CHECK-NEXT:   affine.for %[[J:.*]] = 0 to 3 { | ||||
|   // CHECK-NEXT:     affine.for %[[K:.*]] = 0 to 2 { | ||||
|  | @ -13,9 +12,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, | |||
|   // CHECK-NEXT:         %[[RHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> | ||||
|   // CHECK-NEXT:         %[[MIN_PREDICATE:.*]] = cmpf olt, %[[LHS]], %[[RHS]] : f32 | ||||
|   // CHECK-NEXT:         %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32 | ||||
|   // CHECK-NEXT:         %[[ISNAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32 | ||||
|   // CHECK-NEXT:         %[[MIN_NONAN:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32 | ||||
|   // CHECK-NEXT:         affine.store %[[MIN_NONAN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> | ||||
|   // CHECK-NEXT:         affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> | ||||
|   // CHECK:      return | ||||
|   "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : | ||||
|       (memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> () | ||||
|  | @ -72,11 +69,8 @@ func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, | |||
| // CHECK-LABEL: func @float_max_op | ||||
| func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, | ||||
|                    %result: memref<7xf32>) -> () { | ||||
|   // CHECK: %[[NAN:.*]] = constant 0x7FC00000 : f32 | ||||
|   // CHECK: %[[CMP:.*]] = cmpf ogt, %[[LHS_IN:.*]], %[[RHS_IN:.*]] : f32 | ||||
|   // CHECK: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
|   // CHECK: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
|   // CHECK: select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32 | ||||
|   // CHECK: %[[CHECK:.*]] = cmpf ogt, %[[ONE:.*]], %[[TWO:.*]] : f32 | ||||
|   // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 | ||||
|   "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} | ||||
|       : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () | ||||
|   return | ||||
|  | @ -96,11 +90,8 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, | |||
| // CHECK-LABEL: func @float_min_op | ||||
| func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, | ||||
|                    %result: memref<7xf32>) -> () { | ||||
|   // CHECK: %[[NAN:.*]] = constant 0x7FC00000 : f32 | ||||
|   // CHECK: %[[CMP:.*]] = cmpf olt, %[[LHS_IN:.*]], %[[RHS_IN:.*]] : f32 | ||||
|   // CHECK: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
|   // CHECK: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
|   // CHECK: select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32 | ||||
|   // CHECK: %[[CHECK:.*]] = cmpf olt, %[[ONE:.*]], %[[TWO:.*]] : f32 | ||||
|   // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 | ||||
|   "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} | ||||
|       : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () | ||||
|   return | ||||
|  |  | |||
|  | @ -70,10 +70,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, | |||
| // CHECK: linalg.generic | ||||
| // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32): | ||||
| // CHECK-NEXT:   %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
| // CHECK-NEXT:   %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
| // CHECK-NEXT:   %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
| // CHECK-NEXT:   %[[NAN:.*]] = constant 0x7FC00000 : f32 | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32 | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
| // CHECK-NEXT:   linalg.yield %[[RESULT]] : f32 | ||||
| 
 | ||||
| // ----- | ||||
|  | @ -951,9 +948,9 @@ func @reduce_maximum(%arg: memref<100x10xf32>, | |||
| // CHECK-NEXT: store | ||||
| // CHECK-NEXT: load | ||||
| // CHECK-NEXT: load | ||||
| // CHECK: cmpf | ||||
| // CHECK: select | ||||
| // CHECK: store | ||||
| // CHECK-NEXT: cmpf | ||||
| // CHECK-NEXT: select | ||||
| // CHECK-NEXT: store | ||||
| // CHECK-NEXT: load | ||||
| // CHECK-NEXT: linalg.yield | ||||
| // CHECK-NEXT: } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue