Fix failing segment_reduction_ops_mlir_bridge_test
By adding support for complex types to GetScalarOfType and using appropriate choice of limits for initial values in the unsorted segment reduction ops. PiperOrigin-RevId: 327061577
This commit is contained in:
parent
a434b7e4ee
commit
9a232c7012
|
@ -65,9 +65,24 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
|
||||||
|
|
||||||
// Returns DenseElementsAttr of rank zero with the given element type and the
|
// Returns DenseElementsAttr of rank zero with the given element type and the
|
||||||
// value.
|
// value.
|
||||||
// Requires `ty` to be either FloatType of IntegerType.
|
// Requires `ty` to be either FloatType, IntegerType, or ComplexType.
|
||||||
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
|
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
|
||||||
|
|
||||||
|
// Enum type used to specify scalar argument to GetScalarLimitOfType.
|
||||||
|
enum ScalarLimit {
|
||||||
|
kLowest, // The scalar corresponding to numeric_limits<T>::lowest.
|
||||||
|
kInfinityLowest, // Like kMax, but returns -infinity where available.
|
||||||
|
kMax, // The scalar corresponding to numeric_limits<T>::max.
|
||||||
|
kInfinityMax, // Like kMax, but returns infinity where available.
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns a scalar limit value for the given type.
|
||||||
|
//
|
||||||
|
// The argument 'limit' describes which scalar value to return.
|
||||||
|
//
|
||||||
|
// Requires `ty` to be either FloatType or IntegerType.
|
||||||
|
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
|
||||||
|
|
||||||
} // namespace hlo
|
} // namespace hlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -60,10 +60,76 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
|
||||||
if (auto float_ty = ty.dyn_cast<FloatType>()) {
|
if (auto float_ty = ty.dyn_cast<FloatType>()) {
|
||||||
APFloat value(float_ty.getFloatSemantics(), raw_value);
|
APFloat value(float_ty.getFloatSemantics(), raw_value);
|
||||||
return DenseElementsAttr::get(scalar_ty, value);
|
return DenseElementsAttr::get(scalar_ty, value);
|
||||||
}
|
} else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
|
||||||
auto int_ty = ty.cast<IntegerType>();
|
|
||||||
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
|
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
|
||||||
return DenseElementsAttr::get(scalar_ty, value);
|
return DenseElementsAttr::get(scalar_ty, value);
|
||||||
|
} else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
|
||||||
|
Type complex_element_ty = complex_ty.getElementType();
|
||||||
|
if (complex_element_ty.isF32()) {
|
||||||
|
return DenseElementsAttr::get(
|
||||||
|
scalar_ty, static_cast<std::complex<float>>(raw_value));
|
||||||
|
} else if (complex_element_ty.isF64()) {
|
||||||
|
return DenseElementsAttr::get(
|
||||||
|
scalar_ty, static_cast<std::complex<double>>(raw_value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
llvm_unreachable("unsupported type");
|
||||||
|
}
|
||||||
|
|
||||||
|
static APFloat GetScalarLimitOfFloatType(FloatType float_ty,
|
||||||
|
ScalarLimit limit) {
|
||||||
|
auto &semantics = float_ty.getFloatSemantics();
|
||||||
|
switch (limit) {
|
||||||
|
case kLowest:
|
||||||
|
return APFloat::getLargest(semantics, /*negative=*/true);
|
||||||
|
case kInfinityLowest:
|
||||||
|
return APFloat::getInf(semantics, /*negative=*/true);
|
||||||
|
case kMax:
|
||||||
|
return APFloat::getLargest(semantics, /*negative=*/false);
|
||||||
|
case kInfinityMax:
|
||||||
|
return APFloat::getInf(semantics, /*negative=*/false);
|
||||||
|
}
|
||||||
|
llvm_unreachable("invalid limit");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a scalar value for the given integer type.
|
||||||
|
//
|
||||||
|
// The argument 'scalar' describes which scalar value to return. `integer_value`
|
||||||
|
// is used to specify the integer value for kInteger. For any other scalar,
|
||||||
|
// integer_value is ignored.
|
||||||
|
static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty,
|
||||||
|
ScalarLimit limit) {
|
||||||
|
unsigned width = integer_ty.getWidth();
|
||||||
|
switch (limit) {
|
||||||
|
case kLowest:
|
||||||
|
case kInfinityLowest:
|
||||||
|
if (integer_ty.isUnsigned()) {
|
||||||
|
return APInt::getMinValue(width);
|
||||||
|
} else {
|
||||||
|
return APInt::getSignedMinValue(width);
|
||||||
|
}
|
||||||
|
|
||||||
|
case kMax:
|
||||||
|
case kInfinityMax:
|
||||||
|
if (integer_ty.isUnsigned()) {
|
||||||
|
return APInt::getMaxValue(width);
|
||||||
|
} else {
|
||||||
|
return APInt::getSignedMaxValue(width);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
llvm_unreachable("invalid limit");
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
|
||||||
|
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
|
||||||
|
if (auto float_ty = ty.dyn_cast<FloatType>()) {
|
||||||
|
return DenseElementsAttr::get(scalar_ty,
|
||||||
|
GetScalarLimitOfFloatType(float_ty, limit));
|
||||||
|
} else if (auto integer_ty = ty.dyn_cast<IntegerType>()) {
|
||||||
|
return DenseElementsAttr::get(
|
||||||
|
scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit));
|
||||||
|
}
|
||||||
|
llvm_unreachable("unsupported type");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace hlo
|
} // namespace hlo
|
||||||
|
|
Loading…
Reference in New Issue