diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 64b42c9..4fc3e7e 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -518,11 +518,19 @@ void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand, } OpFoldResult ConvertOp::fold(ArrayRef operands) { - if (getOperand().getType() == getResult().getType()) return getOperand(); + auto operand_ty = getOperand().getType().cast(); + auto result_ty = getResult().getType().cast(); + if (operand_ty == result_ty) return getOperand(); // If the result has non-static shape, a convert op is necessary to go from // static shape to non-static shape. - if (!getResult().getType().cast().hasStaticShape()) return {}; + if (!result_ty.hasStaticShape()) return {}; + + // TODO(hinsu): Handle unsigned types. + if (operand_ty.getElementType().isUnsignedInteger() || + result_ty.getElementType().isUnsignedInteger()) { + return {}; + } // If the operand is constant, we can do the conversion now. if (auto elementsAttr = operands.front().dyn_cast_or_null()) {