Fix failing segment_reduction_ops_mlir_bridge_test

By adding support for complex types to GetScalarOfType and using appropriate
choice of limits for initial values in the unsorted segment reduction ops.

PiperOrigin-RevId: 327061577
This commit is contained in:
Richard Uhler 2020-08-17 11:27:29 -07:00 committed by TensorFlow MLIR Team
parent a434b7e4ee
commit 9a232c7012
2 changed files with 85 additions and 4 deletions

View File

@ -65,9 +65,24 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
// Returns DenseElementsAttr of rank zero with the given element type and the
// value.
// Requires `ty` to be either FloatType of IntegerType.
// Requires `ty` to be either FloatType, IntegerType, or ComplexType.
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
// Enum type used to specify scalar argument to GetScalarLimitOfType.
enum ScalarLimit {
kLowest, // The scalar corresponding to numeric_limits<T>::lowest.
kInfinityLowest, // Like kMax, but returns -infinity where available.
kMax, // The scalar corresponding to numeric_limits<T>::max.
kInfinityMax, // Like kMax, but returns infinity where available.
};
// Returns a scalar limit value for the given type.
//
// The argument 'limit' describes which scalar value to return.
//
// Requires `ty` to be either FloatType or IntegerType.
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
} // namespace hlo
} // namespace mlir

View File

@ -60,10 +60,76 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
if (auto float_ty = ty.dyn_cast<FloatType>()) {
APFloat value(float_ty.getFloatSemantics(), raw_value);
return DenseElementsAttr::get(scalar_ty, value);
} else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
return DenseElementsAttr::get(scalar_ty, value);
} else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
Type complex_element_ty = complex_ty.getElementType();
if (complex_element_ty.isF32()) {
return DenseElementsAttr::get(
scalar_ty, static_cast<std::complex<float>>(raw_value));
} else if (complex_element_ty.isF64()) {
return DenseElementsAttr::get(
scalar_ty, static_cast<std::complex<double>>(raw_value));
}
}
auto int_ty = ty.cast<IntegerType>();
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
return DenseElementsAttr::get(scalar_ty, value);
llvm_unreachable("unsupported type");
}
static APFloat GetScalarLimitOfFloatType(FloatType float_ty,
ScalarLimit limit) {
auto &semantics = float_ty.getFloatSemantics();
switch (limit) {
case kLowest:
return APFloat::getLargest(semantics, /*negative=*/true);
case kInfinityLowest:
return APFloat::getInf(semantics, /*negative=*/true);
case kMax:
return APFloat::getLargest(semantics, /*negative=*/false);
case kInfinityMax:
return APFloat::getInf(semantics, /*negative=*/false);
}
llvm_unreachable("invalid limit");
}
// Returns a scalar value for the given integer type.
//
// The argument 'scalar' describes which scalar value to return. `integer_value`
// is used to specify the integer value for kInteger. For any other scalar,
// integer_value is ignored.
static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty,
ScalarLimit limit) {
unsigned width = integer_ty.getWidth();
switch (limit) {
case kLowest:
case kInfinityLowest:
if (integer_ty.isUnsigned()) {
return APInt::getMinValue(width);
} else {
return APInt::getSignedMinValue(width);
}
case kMax:
case kInfinityMax:
if (integer_ty.isUnsigned()) {
return APInt::getMaxValue(width);
} else {
return APInt::getSignedMaxValue(width);
}
}
llvm_unreachable("invalid limit");
}
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
if (auto float_ty = ty.dyn_cast<FloatType>()) {
return DenseElementsAttr::get(scalar_ty,
GetScalarLimitOfFloatType(float_ty, limit));
} else if (auto integer_ty = ty.dyn_cast<IntegerType>()) {
return DenseElementsAttr::get(
scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit));
}
llvm_unreachable("unsupported type");
}
} // namespace hlo