Remove fold of `mhlo.compare(%arg0, %arg0)` for floating types.
Two tensors having the same SSA-value isn't sufficient for equality for floating types, as `NaN != NaN`. As written this causes `tf.IsNan` to [miscompile](https://github.com/google/iree/issues/4061). PiperOrigin-RevId: 345730640
This commit is contained in:
parent
9bd1995f90
commit
c33bdcbd03
|
@ -2782,11 +2782,10 @@ OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (!result_ty.hasStaticShape()) return {};
|
if (!result_ty.hasStaticShape()) return {};
|
||||||
|
|
||||||
auto direction = comparison_direction();
|
auto direction = comparison_direction();
|
||||||
if (lhs() == rhs()) {
|
if (lhs() == rhs() && !getElementTypeOrSelf(lhs()).isa<FloatType>()) {
|
||||||
if (direction == "LE" || direction == "EQ" || direction == "GE") {
|
if (direction == "LE" || direction == "EQ" || direction == "GE") {
|
||||||
return DenseIntElementsAttr::get(result_ty, {true});
|
return DenseIntElementsAttr::get(result_ty, {true});
|
||||||
}
|
}
|
||||||
|
|
||||||
return DenseIntElementsAttr::get(result_ty, {false});
|
return DenseIntElementsAttr::get(result_ty, {false});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -674,6 +674,14 @@ func @fold_compare_same_gt(%arg0: tensor<i64>) -> tensor<i1> {
|
||||||
return %0 : tensor<i1>
|
return %0 : tensor<i1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Address NaN != NaN.
|
||||||
|
// CHECK-LABEL: dont_fold_compare_same_eq_float
|
||||||
|
func @dont_fold_compare_same_eq_float(%arg0: tensor<f16>) -> tensor<i1> {
|
||||||
|
// CHECK: %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<f16>, tensor<f16>) -> tensor<i1>
|
||||||
|
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<f16>, tensor<f16>) -> tensor<i1>
|
||||||
|
return %0 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: fold_compare_false_eq
|
// CHECK-LABEL: fold_compare_false_eq
|
||||||
func @fold_compare_false_eq() -> tensor<i1> {
|
func @fold_compare_false_eq() -> tensor<i1> {
|
||||||
%0 = mhlo.constant dense<0> : tensor<i32>
|
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||||
|
|
Loading…
Reference in New Issue