diff --git a/include/mlir-hlo/utils/hlo_utils.h b/include/mlir-hlo/utils/hlo_utils.h index 1e335ae..74ea9c9 100644 --- a/include/mlir-hlo/utils/hlo_utils.h +++ b/include/mlir-hlo/utils/hlo_utils.h @@ -65,9 +65,24 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) { // Returns DenseElementsAttr of rank zero with the given element type and the // value. -// Requires `ty` to be either FloatType of IntegerType. +// Requires `ty` to be either FloatType, IntegerType, or ComplexType. DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value); +// Enum type used to specify scalar argument to GetScalarLimitOfType. +enum ScalarLimit { + kLowest, // The scalar corresponding to numeric_limits::lowest. + kInfinityLowest, // Like kMax, but returns -infinity where available. + kMax, // The scalar corresponding to numeric_limits::max. + kInfinityMax, // Like kMax, but returns infinity where available. +}; + +// Returns a scalar limit value for the given type. +// +// The argument 'limit' describes which scalar value to return. +// +// Requires `ty` to be either FloatType or IntegerType. +DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit); + } // namespace hlo } // namespace mlir diff --git a/lib/utils/hlo_utils.cc b/lib/utils/hlo_utils.cc index df2442c..0bbd91e 100644 --- a/lib/utils/hlo_utils.cc +++ b/lib/utils/hlo_utils.cc @@ -60,10 +60,76 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) { if (auto float_ty = ty.dyn_cast()) { APFloat value(float_ty.getFloatSemantics(), raw_value); return DenseElementsAttr::get(scalar_ty, value); + } else if (auto int_ty = ty.dyn_cast()) { + APInt value(int_ty.getWidth(), static_cast(raw_value), true); + return DenseElementsAttr::get(scalar_ty, value); + } else if (auto complex_ty = ty.dyn_cast()) { + Type complex_element_ty = complex_ty.getElementType(); + if (complex_element_ty.isF32()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } else if (complex_element_ty.isF64()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } } - auto int_ty = ty.cast(); - APInt value(int_ty.getWidth(), static_cast(raw_value), true); - return DenseElementsAttr::get(scalar_ty, value); + llvm_unreachable("unsupported type"); +} + +static APFloat GetScalarLimitOfFloatType(FloatType float_ty, + ScalarLimit limit) { + auto &semantics = float_ty.getFloatSemantics(); + switch (limit) { + case kLowest: + return APFloat::getLargest(semantics, /*negative=*/true); + case kInfinityLowest: + return APFloat::getInf(semantics, /*negative=*/true); + case kMax: + return APFloat::getLargest(semantics, /*negative=*/false); + case kInfinityMax: + return APFloat::getInf(semantics, /*negative=*/false); + } + llvm_unreachable("invalid limit"); +} + +// Returns a scalar value for the given integer type. +// +// The argument 'scalar' describes which scalar value to return. `integer_value` +// is used to specify the integer value for kInteger. For any other scalar, +// integer_value is ignored. +static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty, + ScalarLimit limit) { + unsigned width = integer_ty.getWidth(); + switch (limit) { + case kLowest: + case kInfinityLowest: + if (integer_ty.isUnsigned()) { + return APInt::getMinValue(width); + } else { + return APInt::getSignedMinValue(width); + } + + case kMax: + case kInfinityMax: + if (integer_ty.isUnsigned()) { + return APInt::getMaxValue(width); + } else { + return APInt::getSignedMaxValue(width); + } + } + llvm_unreachable("invalid limit"); +} + +DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) { + RankedTensorType scalar_ty = RankedTensorType::get({}, ty); + if (auto float_ty = ty.dyn_cast()) { + return DenseElementsAttr::get(scalar_ty, + GetScalarLimitOfFloatType(float_ty, limit)); + } else if (auto integer_ty = ty.dyn_cast()) { + return DenseElementsAttr::get( + scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit)); + } + llvm_unreachable("unsupported type"); } } // namespace hlo