Fix failing segment_reduction_ops_mlir_bridge_test
By adding support for complex types to GetScalarOfType and using appropriate choice of limits for initial values in the unsorted segment reduction ops. PiperOrigin-RevId: 327061577
This commit is contained in:
parent
a434b7e4ee
commit
9a232c7012
|
@ -65,9 +65,24 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
|
|||
|
||||
// Returns DenseElementsAttr of rank zero with the given element type and the
|
||||
// value.
|
||||
// Requires `ty` to be either FloatType of IntegerType.
|
||||
// Requires `ty` to be either FloatType, IntegerType, or ComplexType.
|
||||
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
|
||||
|
||||
// Enum type used to specify scalar argument to GetScalarLimitOfType.
|
||||
enum ScalarLimit {
|
||||
kLowest, // The scalar corresponding to numeric_limits<T>::lowest.
|
||||
kInfinityLowest, // Like kMax, but returns -infinity where available.
|
||||
kMax, // The scalar corresponding to numeric_limits<T>::max.
|
||||
kInfinityMax, // Like kMax, but returns infinity where available.
|
||||
};
|
||||
|
||||
// Returns a scalar limit value for the given type.
|
||||
//
|
||||
// The argument 'limit' describes which scalar value to return.
|
||||
//
|
||||
// Requires `ty` to be either FloatType or IntegerType.
|
||||
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -60,10 +60,76 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
|
|||
if (auto float_ty = ty.dyn_cast<FloatType>()) {
|
||||
APFloat value(float_ty.getFloatSemantics(), raw_value);
|
||||
return DenseElementsAttr::get(scalar_ty, value);
|
||||
}
|
||||
auto int_ty = ty.cast<IntegerType>();
|
||||
} else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
|
||||
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
|
||||
return DenseElementsAttr::get(scalar_ty, value);
|
||||
} else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
|
||||
Type complex_element_ty = complex_ty.getElementType();
|
||||
if (complex_element_ty.isF32()) {
|
||||
return DenseElementsAttr::get(
|
||||
scalar_ty, static_cast<std::complex<float>>(raw_value));
|
||||
} else if (complex_element_ty.isF64()) {
|
||||
return DenseElementsAttr::get(
|
||||
scalar_ty, static_cast<std::complex<double>>(raw_value));
|
||||
}
|
||||
}
|
||||
llvm_unreachable("unsupported type");
|
||||
}
|
||||
|
||||
static APFloat GetScalarLimitOfFloatType(FloatType float_ty,
|
||||
ScalarLimit limit) {
|
||||
auto &semantics = float_ty.getFloatSemantics();
|
||||
switch (limit) {
|
||||
case kLowest:
|
||||
return APFloat::getLargest(semantics, /*negative=*/true);
|
||||
case kInfinityLowest:
|
||||
return APFloat::getInf(semantics, /*negative=*/true);
|
||||
case kMax:
|
||||
return APFloat::getLargest(semantics, /*negative=*/false);
|
||||
case kInfinityMax:
|
||||
return APFloat::getInf(semantics, /*negative=*/false);
|
||||
}
|
||||
llvm_unreachable("invalid limit");
|
||||
}
|
||||
|
||||
// Returns a scalar value for the given integer type.
|
||||
//
|
||||
// The argument 'scalar' describes which scalar value to return. `integer_value`
|
||||
// is used to specify the integer value for kInteger. For any other scalar,
|
||||
// integer_value is ignored.
|
||||
static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty,
|
||||
ScalarLimit limit) {
|
||||
unsigned width = integer_ty.getWidth();
|
||||
switch (limit) {
|
||||
case kLowest:
|
||||
case kInfinityLowest:
|
||||
if (integer_ty.isUnsigned()) {
|
||||
return APInt::getMinValue(width);
|
||||
} else {
|
||||
return APInt::getSignedMinValue(width);
|
||||
}
|
||||
|
||||
case kMax:
|
||||
case kInfinityMax:
|
||||
if (integer_ty.isUnsigned()) {
|
||||
return APInt::getMaxValue(width);
|
||||
} else {
|
||||
return APInt::getSignedMaxValue(width);
|
||||
}
|
||||
}
|
||||
llvm_unreachable("invalid limit");
|
||||
}
|
||||
|
||||
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
|
||||
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
|
||||
if (auto float_ty = ty.dyn_cast<FloatType>()) {
|
||||
return DenseElementsAttr::get(scalar_ty,
|
||||
GetScalarLimitOfFloatType(float_ty, limit));
|
||||
} else if (auto integer_ty = ty.dyn_cast<IntegerType>()) {
|
||||
return DenseElementsAttr::get(
|
||||
scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit));
|
||||
}
|
||||
llvm_unreachable("unsupported type");
|
||||
}
|
||||
|
||||
} // namespace hlo
|
||||
|
|
Loading…
Reference in New Issue