Integrate LLVM at llvm/llvm-project@202766947e
Updates LLVM usage to match [202766947edb](https://github.com/llvm/llvm-project/commit/202766947edb) PiperOrigin-RevId: 329673065
This commit is contained in:
parent
7c93352a40
commit
65b0613491
|
@ -1,2 +1,2 @@
|
||||||
5ffd940ac02a8e000691c45a6dc4f69d0198e675
|
202766947edb5407b84484185608aac077085608
|
||||||
|
|
||||||
|
|
|
@ -172,7 +172,7 @@ static LogicalResult Verify(DotGeneralOp op) {
|
||||||
/// Fold get_dimension_size when the said shape dimension is a constant.
|
/// Fold get_dimension_size when the said shape dimension is a constant.
|
||||||
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
|
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
|
||||||
RankedTensorType type = operand().getType().cast<RankedTensorType>();
|
RankedTensorType type = operand().getType().cast<RankedTensorType>();
|
||||||
int32_t dim = dimension().getSExtValue();
|
int32_t dim = dimension();
|
||||||
if (type.isDynamic(dim)) return {};
|
if (type.isDynamic(dim)) return {};
|
||||||
// The result type is always is a 0-d i32 tensor.
|
// The result type is always is a 0-d i32 tensor.
|
||||||
return DenseIntElementsAttr::get<int32_t>(
|
return DenseIntElementsAttr::get<int32_t>(
|
||||||
|
@ -190,7 +190,7 @@ static LogicalResult Verify(IotaOp op) {
|
||||||
if (shape.getRank() == 0)
|
if (shape.getRank() == 0)
|
||||||
return op.emitOpError() << "does not support scalars.";
|
return op.emitOpError() << "does not support scalars.";
|
||||||
|
|
||||||
auto iota_dimension = op.iota_dimension().getSExtValue();
|
auto iota_dimension = op.iota_dimension();
|
||||||
if (iota_dimension >= shape.getRank() || iota_dimension < 0)
|
if (iota_dimension >= shape.getRank() || iota_dimension < 0)
|
||||||
return op.emitOpError() << "iota dimension cannot go beyond the output "
|
return op.emitOpError() << "iota dimension cannot go beyond the output "
|
||||||
"rank or be negative.";
|
"rank or be negative.";
|
||||||
|
@ -212,8 +212,7 @@ struct IotaBroadcast : public OpRewritePattern<IotaOp> {
|
||||||
auto iota_dimension = iota.iota_dimension();
|
auto iota_dimension = iota.iota_dimension();
|
||||||
|
|
||||||
auto iota_type = RankedTensorType::get(
|
auto iota_type = RankedTensorType::get(
|
||||||
{result_ty.getDimSize(iota_dimension.getLimitedValue())},
|
{result_ty.getDimSize(iota_dimension)}, result_ty.getElementType());
|
||||||
result_ty.getElementType());
|
|
||||||
|
|
||||||
auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
|
auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
|
||||||
rewriter.getI64IntegerAttr(0));
|
rewriter.getI64IntegerAttr(0));
|
||||||
|
@ -233,7 +232,7 @@ void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
|
||||||
auto dimension = iota_dimension().getLimitedValue();
|
auto dimension = iota_dimension();
|
||||||
auto result_ty = getResult().getType().cast<ShapedType>();
|
auto result_ty = getResult().getType().cast<ShapedType>();
|
||||||
if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
|
if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
|
||||||
Builder builder(getContext());
|
Builder builder(getContext());
|
||||||
|
@ -277,7 +276,7 @@ struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto iota_dimension = iota.iota_dimension();
|
auto iota_dimension = iota.iota_dimension();
|
||||||
auto iota_dimension_int = iota_dimension.getLimitedValue();
|
auto iota_dimension_int = iota_dimension;
|
||||||
|
|
||||||
auto converted_shape = rewriter.create<IndexCastOp>(
|
auto converted_shape = rewriter.create<IndexCastOp>(
|
||||||
iota.getLoc(),
|
iota.getLoc(),
|
||||||
|
@ -476,7 +475,7 @@ static LogicalResult Verify(DequantizeOp op) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult Verify(GetTupleElementOp op) {
|
static LogicalResult Verify(GetTupleElementOp op) {
|
||||||
auto indexVal = op.index().getZExtValue();
|
auto indexVal = op.index();
|
||||||
auto operandType = op.getOperand().getType().cast<TupleType>();
|
auto operandType = op.getOperand().getType().cast<TupleType>();
|
||||||
if (indexVal >= operandType.size()) {
|
if (indexVal >= operandType.size()) {
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
|
@ -495,7 +494,7 @@ static LogicalResult Verify(GetTupleElementOp op) {
|
||||||
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (auto tupleOp =
|
if (auto tupleOp =
|
||||||
dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
|
dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
|
||||||
return tupleOp.getOperand(index().getLimitedValue());
|
return tupleOp.getOperand(index());
|
||||||
}
|
}
|
||||||
|
|
||||||
return {};
|
return {};
|
||||||
|
@ -565,8 +564,8 @@ static LogicalResult Verify(AllToAllOp op) {
|
||||||
// count.
|
// count.
|
||||||
auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!type) return success();
|
if (!type) return success();
|
||||||
auto split_dim_size = type.getDimSize(op.split_dimension().getSExtValue());
|
auto split_dim_size = type.getDimSize(op.split_dimension());
|
||||||
auto split_count = op.split_count().getSExtValue();
|
auto split_count = op.split_count();
|
||||||
if (split_dim_size % split_count != 0) {
|
if (split_dim_size % split_count != 0) {
|
||||||
return op.emitError() << "split dimension has size " << split_dim_size
|
return op.emitError() << "split dimension has size " << split_dim_size
|
||||||
<< ", expected to be a multiple of split_count "
|
<< ", expected to be a multiple of split_count "
|
||||||
|
@ -914,7 +913,7 @@ class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(ConcatenateOp op,
|
LogicalResult matchAndRewrite(ConcatenateOp op,
|
||||||
PatternRewriter& rewriter) const override {
|
PatternRewriter& rewriter) const override {
|
||||||
auto axis = op.dimension().getLimitedValue();
|
auto axis = op.dimension();
|
||||||
llvm::SmallVector<Value, 6> new_operands;
|
llvm::SmallVector<Value, 6> new_operands;
|
||||||
for (auto operand : op.getOperands()) {
|
for (auto operand : op.getOperands()) {
|
||||||
auto ty = operand.getType().cast<ShapedType>();
|
auto ty = operand.getType().cast<ShapedType>();
|
||||||
|
@ -994,7 +993,7 @@ void ConcatenateOp::getCanonicalizationPatterns(
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static Attribute foldConcatenateHelper(ConcatenateOp* op,
|
static Attribute foldConcatenateHelper(ConcatenateOp* op,
|
||||||
ArrayRef<Attribute> operands) {
|
ArrayRef<Attribute> operands) {
|
||||||
auto axis = op->dimension().getLimitedValue();
|
auto axis = op->dimension();
|
||||||
auto type = op->getType().cast<ShapedType>();
|
auto type = op->getType().cast<ShapedType>();
|
||||||
|
|
||||||
SmallVector<T, 6> values;
|
SmallVector<T, 6> values;
|
||||||
|
@ -1042,7 +1041,7 @@ OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
|
||||||
ShapedType type = getResult().getType().cast<ShapedType>();
|
ShapedType type = getResult().getType().cast<ShapedType>();
|
||||||
if (!type.hasStaticShape()) return {};
|
if (!type.hasStaticShape()) return {};
|
||||||
|
|
||||||
auto axis = dimension().getLimitedValue();
|
auto axis = dimension();
|
||||||
if (auto attr = foldConcatenate(this, operands)) {
|
if (auto attr = foldConcatenate(this, operands)) {
|
||||||
return attr;
|
return attr;
|
||||||
}
|
}
|
||||||
|
@ -1845,7 +1844,7 @@ struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto dimension = concat.dimension().getSExtValue();
|
auto dimension = concat.dimension();
|
||||||
|
|
||||||
auto start = slice.start_indices().getIntValues();
|
auto start = slice.start_indices().getIntValues();
|
||||||
auto limit = slice.limit_indices().getIntValues();
|
auto limit = slice.limit_indices().getIntValues();
|
||||||
|
@ -1995,7 +1994,7 @@ static LogicalResult Verify(SortOp op) {
|
||||||
return op.emitOpError("requires all inputs to have the same dimensions");
|
return op.emitOpError("requires all inputs to have the same dimensions");
|
||||||
|
|
||||||
int64_t rank = input_shape.size();
|
int64_t rank = input_shape.size();
|
||||||
int64_t cmp_dim = op.dimension().getSExtValue();
|
int64_t cmp_dim = op.dimension();
|
||||||
if (cmp_dim < -rank || cmp_dim >= rank)
|
if (cmp_dim < -rank || cmp_dim >= rank)
|
||||||
return op.emitOpError("dimension attribute value must be in range [-")
|
return op.emitOpError("dimension attribute value must be in range [-")
|
||||||
<< rank << ", " << rank << "), but found " << cmp_dim;
|
<< rank << ", " << rank << "), but found " << cmp_dim;
|
||||||
|
|
|
@ -704,7 +704,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
|
||||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
|
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
|
||||||
ValueRange args) {
|
ValueRange args) {
|
||||||
Value castOp = nestedBuilder.create<IndexCastOp>(
|
Value castOp = nestedBuilder.create<IndexCastOp>(
|
||||||
nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
|
nestedLoc, ivs[iotaOp.iota_dimension()],
|
||||||
nestedBuilder.getIntegerType(
|
nestedBuilder.getIntegerType(
|
||||||
resultElementType.getIntOrFloatBitWidth()));
|
resultElementType.getIntOrFloatBitWidth()));
|
||||||
if (resultElementType.template isa<FloatType>()) {
|
if (resultElementType.template isa<FloatType>()) {
|
||||||
|
|
|
@ -117,7 +117,7 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto output_type = op.getType().cast<ShapedType>();
|
auto output_type = op.getType().cast<ShapedType>();
|
||||||
auto output_size = output_type.getNumElements();
|
auto output_size = output_type.getNumElements();
|
||||||
auto dimension = op.iota_dimension().getSExtValue();
|
auto dimension = op.iota_dimension();
|
||||||
auto max_dim_size = output_type.getDimSize(dimension);
|
auto max_dim_size = output_type.getDimSize(dimension);
|
||||||
|
|
||||||
auto element_type = output_type.getElementType();
|
auto element_type = output_type.getElementType();
|
||||||
|
|
|
@ -80,7 +80,7 @@ void MatchAndRewrite(WhileOp whileOp) {
|
||||||
// the external value is captured.
|
// the external value is captured.
|
||||||
if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
|
if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
|
||||||
if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
|
if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
|
||||||
int index = gte.index().getSExtValue();
|
int index = gte.index();
|
||||||
return {tupleOp.getOperand(index), index};
|
return {tupleOp.getOperand(index), index};
|
||||||
}
|
}
|
||||||
return {nullptr, 0};
|
return {nullptr, 0};
|
||||||
|
@ -154,7 +154,7 @@ void MatchAndRewrite(WhileOp whileOp) {
|
||||||
use->erase();
|
use->erase();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
int index = gte.index().getSExtValue();
|
int index = gte.index();
|
||||||
// If after the loop induction variable, then decrement as we don't include
|
// If after the loop induction variable, then decrement as we don't include
|
||||||
// the loop induction variable in the for iter operands.
|
// the loop induction variable in the for iter operands.
|
||||||
if (index > loopIndVar.second) --index;
|
if (index > loopIndVar.second) --index;
|
||||||
|
|
|
@ -122,7 +122,7 @@ class UnfuseBatchNormInferencePattern
|
||||||
if (!fp_type) {
|
if (!fp_type) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
int64_t feature_dim = bn_op.feature_index().getSExtValue();
|
int64_t feature_dim = bn_op.feature_index();
|
||||||
|
|
||||||
// Add epsilon to the variance and sqrt to get stddev:
|
// Add epsilon to the variance and sqrt to get stddev:
|
||||||
// stddev = sqrt(variance + epsilon)
|
// stddev = sqrt(variance + epsilon)
|
||||||
|
|
Loading…
Reference in New Issue