Fix handling of negative seeds in random number generator op kernels for XLA

Casting negative s32 number to u64 directly will have leading 1s in the representation which is not what we want to get a single u64 out of two s32 seeds. Fixed this by first getting unsigned number of the same bit-width.

PiperOrigin-RevId: 345902167
This commit is contained in:
Smit Hinsu 2020-12-05 18:54:37 -08:00 committed by TensorFlow MLIR Team
parent 55268f9ee8
commit bc7b6374c8
1 changed files with 10 additions and 2 deletions

View File

@ -518,11 +518,19 @@ void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
} }
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
if (getOperand().getType() == getResult().getType()) return getOperand(); auto operand_ty = getOperand().getType().cast<TensorType>();
auto result_ty = getResult().getType().cast<TensorType>();
if (operand_ty == result_ty) return getOperand();
// If the result has non-static shape, a convert op is necessary to go from // If the result has non-static shape, a convert op is necessary to go from
// static shape to non-static shape. // static shape to non-static shape.
if (!getResult().getType().cast<TensorType>().hasStaticShape()) return {}; if (!result_ty.hasStaticShape()) return {};
// TODO(hinsu): Handle unsigned types.
if (operand_ty.getElementType().isUnsignedInteger() ||
result_ty.getElementType().isUnsignedInteger()) {
return {};
}
// If the operand is constant, we can do the conversion now. // If the operand is constant, we can do the conversion now.
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) { if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {