[mlir][hlo] Make min/max always propagate NaNs
This is the right behavior for TF and JAX and matches what TF does on GPU. It doesn't match TF on CPU, but that's really a TF bug. PiperOrigin-RevId: 353657779
This commit is contained in:
		
							parent
							
								
									b1438eebcb
								
							
						
					
					
						commit
						f6b24a6d54
					
				| 
						 | 
					@ -437,6 +437,23 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
 | 
				
			||||||
      loc, result_types, args, b);
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc,
 | 
				
			||||||
 | 
					                                    OpBuilder* b) {
 | 
				
			||||||
 | 
					  Type element_type = getElementTypeOrSelf(args.front().getType());
 | 
				
			||||||
 | 
					  if (auto float_type = element_type.dyn_cast<FloatType>()) {
 | 
				
			||||||
 | 
					    Value isnan =
 | 
				
			||||||
 | 
					        b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[1]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto nan_apfloat = APFloat::getQNaN(float_type.getFloatSemantics());
 | 
				
			||||||
 | 
					    Value nan = b->create<mlir::ConstantFloatOp>(loc, nan_apfloat, float_type);
 | 
				
			||||||
 | 
					    if (VectorType vec_type = args[0].getType().dyn_cast<VectorType>()) {
 | 
				
			||||||
 | 
					      nan = b->create<::mlir::SplatOp>(loc, vec_type, nan);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    v = b->create<mlir::SelectOp>(loc, isnan, nan, v);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return v;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
 | 
					inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
 | 
				
			||||||
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
| 
						 | 
					@ -464,10 +481,13 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
 | 
				
			||||||
                                                  ArrayRef<Type> result_types,
 | 
					                                                  ArrayRef<Type> result_types,
 | 
				
			||||||
                                                  ArrayRef<Value> args,
 | 
					                                                  ArrayRef<Value> args,
 | 
				
			||||||
                                                  OpBuilder* b) {
 | 
					                                                  OpBuilder* b) {
 | 
				
			||||||
  return CompareSelectOpToStdScalarOp<
 | 
					  return LhloAlwaysPropagateNaN(
 | 
				
			||||||
 | 
					      CompareSelectOpToStdScalarOp<
 | 
				
			||||||
          IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
 | 
					          IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
 | 
				
			||||||
      ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types,
 | 
					          ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
 | 
				
			||||||
                                                       args, b);
 | 
					                                                           result_types, args,
 | 
				
			||||||
 | 
					                                                           b),
 | 
				
			||||||
 | 
					      args, loc, b);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
| 
						 | 
					@ -475,10 +495,13 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
 | 
				
			||||||
                                                  ArrayRef<Type> result_types,
 | 
					                                                  ArrayRef<Type> result_types,
 | 
				
			||||||
                                                  ArrayRef<Value> args,
 | 
					                                                  ArrayRef<Value> args,
 | 
				
			||||||
                                                  OpBuilder* b) {
 | 
					                                                  OpBuilder* b) {
 | 
				
			||||||
  return CompareSelectOpToStdScalarOp<
 | 
					  return LhloAlwaysPropagateNaN(
 | 
				
			||||||
 | 
					      CompareSelectOpToStdScalarOp<
 | 
				
			||||||
          IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
 | 
					          IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
 | 
				
			||||||
      ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types,
 | 
					          ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
 | 
				
			||||||
                                                       args, b);
 | 
					                                                           result_types, args,
 | 
				
			||||||
 | 
					                                                           b),
 | 
				
			||||||
 | 
					      args, loc, b);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -540,7 +540,10 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
				
			||||||
// CHECK: linalg.generic
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32):
 | 
					// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32):
 | 
				
			||||||
// CHECK-NEXT:   %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
					// CHECK-NEXT:   %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
				
			||||||
// CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
					// CHECK-NEXT:   %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
				
			||||||
 | 
					// CHECK-NEXT:   %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
				
			||||||
 | 
					// CHECK-NEXT:   %[[NAN:.*]] = constant 0x7FC00000 : f32
 | 
				
			||||||
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32
 | 
				
			||||||
// CHECK-NEXT:   linalg.yield %[[RESULT]] : f32
 | 
					// CHECK-NEXT:   linalg.yield %[[RESULT]] : f32
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
| 
						 | 
					@ -950,10 +953,14 @@ func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>)
 | 
				
			||||||
  // CHECK: %[[INIT:.*]] = linalg.init_tensor
 | 
					  // CHECK: %[[INIT:.*]] = linalg.init_tensor
 | 
				
			||||||
  // CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
 | 
					  // CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
 | 
				
			||||||
  // CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32):
 | 
					  // CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32):
 | 
				
			||||||
  // CHECK:   %[[LT_X_UB:.*]] = cmpf olt, %[[SCALAR_X]], %[[SCALAR_UB]]
 | 
					  // CHECK:   cmpf olt
 | 
				
			||||||
  // CHECK:   %[[X2:.*]] = select %[[LT_X_UB]], %[[SCALAR_X]], %[[SCALAR_UB]]
 | 
					  // CHECK:   select
 | 
				
			||||||
  // CHECK:   %[[GT_X2_LB:.*]] = cmpf ogt, %[[X2]], %[[SCALAR_LB]]
 | 
					  // CHECK:   cmpf uno
 | 
				
			||||||
  // CHECK:   %[[MAX_X2_LB:.*]] = select %[[GT_X2_LB]], %[[X2]], %[[SCALAR_LB]]
 | 
					  // CHECK:   select
 | 
				
			||||||
 | 
					  // CHECK:   cmpf ogt
 | 
				
			||||||
 | 
					  // CHECK:   select
 | 
				
			||||||
 | 
					  // CHECK:   cmpf uno
 | 
				
			||||||
 | 
					  // CHECK:   %[[MAX_X2_LB:.*]] = select
 | 
				
			||||||
  // CHECK:   linalg.yield %[[MAX_X2_LB]]
 | 
					  // CHECK:   linalg.yield %[[MAX_X2_LB]]
 | 
				
			||||||
  // CHECK: } -> tensor<4xf32>
 | 
					  // CHECK: } -> tensor<4xf32>
 | 
				
			||||||
  // CHECK: return %[[RESULT]] : tensor<4xf32>
 | 
					  // CHECK: return %[[RESULT]] : tensor<4xf32>
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4,6 +4,7 @@
 | 
				
			||||||
// CHECK-LABEL: func @min_op
 | 
					// CHECK-LABEL: func @min_op
 | 
				
			||||||
func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
 | 
					func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
 | 
				
			||||||
             %result: memref<4x3x2x1xf32>) -> () {
 | 
					             %result: memref<4x3x2x1xf32>) -> () {
 | 
				
			||||||
 | 
					  // CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
 | 
				
			||||||
  // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 4 {
 | 
					  // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 4 {
 | 
				
			||||||
  // CHECK-NEXT:   affine.for %[[J:.*]] = 0 to 3 {
 | 
					  // CHECK-NEXT:   affine.for %[[J:.*]] = 0 to 3 {
 | 
				
			||||||
  // CHECK-NEXT:     affine.for %[[K:.*]] = 0 to 2 {
 | 
					  // CHECK-NEXT:     affine.for %[[K:.*]] = 0 to 2 {
 | 
				
			||||||
| 
						 | 
					@ -12,7 +13,9 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
 | 
				
			||||||
  // CHECK-NEXT:         %[[RHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
 | 
					  // CHECK-NEXT:         %[[RHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
 | 
				
			||||||
  // CHECK-NEXT:         %[[MIN_PREDICATE:.*]] = cmpf olt, %[[LHS]], %[[RHS]] : f32
 | 
					  // CHECK-NEXT:         %[[MIN_PREDICATE:.*]] = cmpf olt, %[[LHS]], %[[RHS]] : f32
 | 
				
			||||||
  // CHECK-NEXT:         %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32
 | 
					  // CHECK-NEXT:         %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32
 | 
				
			||||||
  // CHECK-NEXT:         affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
 | 
					  // CHECK-NEXT:         %[[ISNAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32
 | 
				
			||||||
 | 
					  // CHECK-NEXT:         %[[MIN_NONAN:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32
 | 
				
			||||||
 | 
					  // CHECK-NEXT:         affine.store %[[MIN_NONAN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
 | 
				
			||||||
  // CHECK:      return
 | 
					  // CHECK:      return
 | 
				
			||||||
  "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
 | 
					  "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
 | 
				
			||||||
      (memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> ()
 | 
					      (memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> ()
 | 
				
			||||||
| 
						 | 
					@ -69,8 +72,11 @@ func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
 | 
				
			||||||
// CHECK-LABEL: func @float_max_op
 | 
					// CHECK-LABEL: func @float_max_op
 | 
				
			||||||
func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
 | 
					func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
 | 
				
			||||||
                   %result: memref<7xf32>) -> () {
 | 
					                   %result: memref<7xf32>) -> () {
 | 
				
			||||||
  // CHECK: %[[CHECK:.*]] = cmpf ogt, %[[ONE:.*]], %[[TWO:.*]] : f32
 | 
					  // CHECK: %[[NAN:.*]] = constant 0x7FC00000 : f32
 | 
				
			||||||
  // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
 | 
					  // CHECK: %[[CMP:.*]] = cmpf ogt, %[[LHS_IN:.*]], %[[RHS_IN:.*]] : f32
 | 
				
			||||||
 | 
					  // CHECK: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
				
			||||||
 | 
					  // CHECK: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
				
			||||||
 | 
					  // CHECK: select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32
 | 
				
			||||||
  "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
 | 
					  "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
 | 
				
			||||||
      : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
 | 
					      : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
 | 
				
			||||||
  return
 | 
					  return
 | 
				
			||||||
| 
						 | 
					@ -90,8 +96,11 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
 | 
				
			||||||
// CHECK-LABEL: func @float_min_op
 | 
					// CHECK-LABEL: func @float_min_op
 | 
				
			||||||
func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
 | 
					func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
 | 
				
			||||||
                   %result: memref<7xf32>) -> () {
 | 
					                   %result: memref<7xf32>) -> () {
 | 
				
			||||||
  // CHECK: %[[CHECK:.*]] = cmpf olt, %[[ONE:.*]], %[[TWO:.*]] : f32
 | 
					  // CHECK: %[[NAN:.*]] = constant 0x7FC00000 : f32
 | 
				
			||||||
  // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
 | 
					  // CHECK: %[[CMP:.*]] = cmpf olt, %[[LHS_IN:.*]], %[[RHS_IN:.*]] : f32
 | 
				
			||||||
 | 
					  // CHECK: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
				
			||||||
 | 
					  // CHECK: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
				
			||||||
 | 
					  // CHECK: select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32
 | 
				
			||||||
  "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
 | 
					  "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
 | 
				
			||||||
      : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
 | 
					      : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
 | 
				
			||||||
  return
 | 
					  return
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -70,7 +70,10 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
 | 
				
			||||||
// CHECK: linalg.generic
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
 | 
					// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
 | 
				
			||||||
// CHECK-NEXT:   %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
					// CHECK-NEXT:   %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
				
			||||||
// CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
					// CHECK-NEXT:   %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
				
			||||||
 | 
					// CHECK-NEXT:   %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
				
			||||||
 | 
					// CHECK-NEXT:   %[[NAN:.*]] = constant 0x7FC00000 : f32
 | 
				
			||||||
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32
 | 
				
			||||||
// CHECK-NEXT:   linalg.yield %[[RESULT]] : f32
 | 
					// CHECK-NEXT:   linalg.yield %[[RESULT]] : f32
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
| 
						 | 
					@ -948,9 +951,9 @@ func @reduce_maximum(%arg: memref<100x10xf32>,
 | 
				
			||||||
// CHECK-NEXT: store
 | 
					// CHECK-NEXT: store
 | 
				
			||||||
// CHECK-NEXT: load
 | 
					// CHECK-NEXT: load
 | 
				
			||||||
// CHECK-NEXT: load
 | 
					// CHECK-NEXT: load
 | 
				
			||||||
// CHECK-NEXT: cmpf
 | 
					// CHECK: cmpf
 | 
				
			||||||
// CHECK-NEXT: select
 | 
					// CHECK: select
 | 
				
			||||||
// CHECK-NEXT: store
 | 
					// CHECK: store
 | 
				
			||||||
// CHECK-NEXT: load
 | 
					// CHECK-NEXT: load
 | 
				
			||||||
// CHECK-NEXT: linalg.yield
 | 
					// CHECK-NEXT: linalg.yield
 | 
				
			||||||
// CHECK-NEXT: }
 | 
					// CHECK-NEXT: }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue