[mlir][hlo] Make min/max always propagate NaNs
This is the right behavior for TF and JAX and matches what TF does on GPU. It doesn't match TF on CPU, but that's really a TF bug. PiperOrigin-RevId: 353657779
This commit is contained in:
parent
b1438eebcb
commit
f6b24a6d54
|
@ -437,6 +437,23 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc,
|
||||||
|
OpBuilder* b) {
|
||||||
|
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||||
|
if (auto float_type = element_type.dyn_cast<FloatType>()) {
|
||||||
|
Value isnan =
|
||||||
|
b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[1]);
|
||||||
|
|
||||||
|
auto nan_apfloat = APFloat::getQNaN(float_type.getFloatSemantics());
|
||||||
|
Value nan = b->create<mlir::ConstantFloatOp>(loc, nan_apfloat, float_type);
|
||||||
|
if (VectorType vec_type = args[0].getType().dyn_cast<VectorType>()) {
|
||||||
|
nan = b->create<::mlir::SplatOp>(loc, vec_type, nan);
|
||||||
|
}
|
||||||
|
v = b->create<mlir::SelectOp>(loc, isnan, nan, v);
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
@ -464,10 +481,13 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return CompareSelectOpToStdScalarOp<
|
return LhloAlwaysPropagateNaN(
|
||||||
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
CompareSelectOpToStdScalarOp<
|
||||||
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types,
|
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||||
args, b);
|
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
|
||||||
|
result_types, args,
|
||||||
|
b),
|
||||||
|
args, loc, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -475,10 +495,13 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return CompareSelectOpToStdScalarOp<
|
return LhloAlwaysPropagateNaN(
|
||||||
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
CompareSelectOpToStdScalarOp<
|
||||||
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types,
|
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||||
args, b);
|
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
|
||||||
|
result_types, args,
|
||||||
|
b),
|
||||||
|
args, loc, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
|
|
@ -540,7 +540,10 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32):
|
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32):
|
||||||
// CHECK-NEXT: %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32
|
// CHECK-NEXT: %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
|
// CHECK-NEXT: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
|
// CHECK-NEXT: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
|
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32
|
||||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -950,10 +953,14 @@ func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>)
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor
|
||||||
// CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
|
// CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
|
||||||
// CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32):
|
// CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32):
|
||||||
// CHECK: %[[LT_X_UB:.*]] = cmpf olt, %[[SCALAR_X]], %[[SCALAR_UB]]
|
// CHECK: cmpf olt
|
||||||
// CHECK: %[[X2:.*]] = select %[[LT_X_UB]], %[[SCALAR_X]], %[[SCALAR_UB]]
|
// CHECK: select
|
||||||
// CHECK: %[[GT_X2_LB:.*]] = cmpf ogt, %[[X2]], %[[SCALAR_LB]]
|
// CHECK: cmpf uno
|
||||||
// CHECK: %[[MAX_X2_LB:.*]] = select %[[GT_X2_LB]], %[[X2]], %[[SCALAR_LB]]
|
// CHECK: select
|
||||||
|
// CHECK: cmpf ogt
|
||||||
|
// CHECK: select
|
||||||
|
// CHECK: cmpf uno
|
||||||
|
// CHECK: %[[MAX_X2_LB:.*]] = select
|
||||||
// CHECK: linalg.yield %[[MAX_X2_LB]]
|
// CHECK: linalg.yield %[[MAX_X2_LB]]
|
||||||
// CHECK: } -> tensor<4xf32>
|
// CHECK: } -> tensor<4xf32>
|
||||||
// CHECK: return %[[RESULT]] : tensor<4xf32>
|
// CHECK: return %[[RESULT]] : tensor<4xf32>
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
// CHECK-LABEL: func @min_op
|
// CHECK-LABEL: func @min_op
|
||||||
func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
|
func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
|
||||||
%result: memref<4x3x2x1xf32>) -> () {
|
%result: memref<4x3x2x1xf32>) -> () {
|
||||||
|
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
|
||||||
// CHECK-NEXT: affine.for %[[I:.*]] = 0 to 4 {
|
// CHECK-NEXT: affine.for %[[I:.*]] = 0 to 4 {
|
||||||
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 3 {
|
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 3 {
|
||||||
// CHECK-NEXT: affine.for %[[K:.*]] = 0 to 2 {
|
// CHECK-NEXT: affine.for %[[K:.*]] = 0 to 2 {
|
||||||
|
@ -12,7 +13,9 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
|
||||||
// CHECK-NEXT: %[[RHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
|
// CHECK-NEXT: %[[RHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
|
||||||
// CHECK-NEXT: %[[MIN_PREDICATE:.*]] = cmpf olt, %[[LHS]], %[[RHS]] : f32
|
// CHECK-NEXT: %[[MIN_PREDICATE:.*]] = cmpf olt, %[[LHS]], %[[RHS]] : f32
|
||||||
// CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32
|
// CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32
|
||||||
// CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
|
// CHECK-NEXT: %[[ISNAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32
|
||||||
|
// CHECK-NEXT: %[[MIN_NONAN:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32
|
||||||
|
// CHECK-NEXT: affine.store %[[MIN_NONAN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
|
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
|
||||||
(memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> ()
|
(memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> ()
|
||||||
|
@ -69,8 +72,11 @@ func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||||
// CHECK-LABEL: func @float_max_op
|
// CHECK-LABEL: func @float_max_op
|
||||||
func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||||
%result: memref<7xf32>) -> () {
|
%result: memref<7xf32>) -> () {
|
||||||
// CHECK: %[[CHECK:.*]] = cmpf ogt, %[[ONE:.*]], %[[TWO:.*]] : f32
|
// CHECK: %[[NAN:.*]] = constant 0x7FC00000 : f32
|
||||||
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
|
// CHECK: %[[CMP:.*]] = cmpf ogt, %[[LHS_IN:.*]], %[[RHS_IN:.*]] : f32
|
||||||
|
// CHECK: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
|
// CHECK: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
|
// CHECK: select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32
|
||||||
"lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
|
"lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
|
||||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||||
return
|
return
|
||||||
|
@ -90,8 +96,11 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||||
// CHECK-LABEL: func @float_min_op
|
// CHECK-LABEL: func @float_min_op
|
||||||
func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||||
%result: memref<7xf32>) -> () {
|
%result: memref<7xf32>) -> () {
|
||||||
// CHECK: %[[CHECK:.*]] = cmpf olt, %[[ONE:.*]], %[[TWO:.*]] : f32
|
// CHECK: %[[NAN:.*]] = constant 0x7FC00000 : f32
|
||||||
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
|
// CHECK: %[[CMP:.*]] = cmpf olt, %[[LHS_IN:.*]], %[[RHS_IN:.*]] : f32
|
||||||
|
// CHECK: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
|
// CHECK: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
|
// CHECK: select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32
|
||||||
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
|
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
|
||||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||||
return
|
return
|
||||||
|
|
|
@ -70,7 +70,10 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
|
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
|
||||||
// CHECK-NEXT: %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32
|
// CHECK-NEXT: %[[CMP:.*]] = cmpf olt, %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
|
// CHECK-NEXT: %[[MIN:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
|
// CHECK-NEXT: %[[ISNAN:.*]] = cmpf uno, %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
|
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = select %[[ISNAN]], %[[NAN]], %[[MIN]] : f32
|
||||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -948,9 +951,9 @@ func @reduce_maximum(%arg: memref<100x10xf32>,
|
||||||
// CHECK-NEXT: store
|
// CHECK-NEXT: store
|
||||||
// CHECK-NEXT: load
|
// CHECK-NEXT: load
|
||||||
// CHECK-NEXT: load
|
// CHECK-NEXT: load
|
||||||
// CHECK-NEXT: cmpf
|
// CHECK: cmpf
|
||||||
// CHECK-NEXT: select
|
// CHECK: select
|
||||||
// CHECK-NEXT: store
|
// CHECK: store
|
||||||
// CHECK-NEXT: load
|
// CHECK-NEXT: load
|
||||||
// CHECK-NEXT: linalg.yield
|
// CHECK-NEXT: linalg.yield
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
|
|
Loading…
Reference in New Issue