diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 4fc3e7e..64b42c9 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -518,19 +518,11 @@ void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand, } OpFoldResult ConvertOp::fold(ArrayRef operands) { - auto operand_ty = getOperand().getType().cast(); - auto result_ty = getResult().getType().cast(); - if (operand_ty == result_ty) return getOperand(); + if (getOperand().getType() == getResult().getType()) return getOperand(); // If the result has non-static shape, a convert op is necessary to go from // static shape to non-static shape. - if (!result_ty.hasStaticShape()) return {}; - - // TODO(hinsu): Handle unsigned types. - if (operand_ty.getElementType().isUnsignedInteger() || - result_ty.getElementType().isUnsignedInteger()) { - return {}; - } + if (!getResult().getType().cast().hasStaticShape()) return {}; // If the operand is constant, we can do the conversion now. if (auto elementsAttr = operands.front().dyn_cast_or_null()) {