Updates LLVM usage to match
[202766947edb](https://github.com/llvm/llvm-project/commit/202766947edb)

PiperOrigin-RevId: 329673065
commit 65b0613491 (parent 7c93352a40)
Authored by A. Unique TensorFlower on 2020-09-02 02:27:09 -07:00; committed by the TensorFlow MLIR Team
6 changed files with 20 additions and 21 deletions
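
The edits below are one mechanical cleanup: as the call-site changes show, at the new LLVM revision the ODS-generated accessors for integer attributes return the underlying integer value rather than an llvm::APInt, so the explicit getSExtValue()/getZExtValue()/getLimitedValue() conversions become redundant and are dropped. A minimal before/after sketch of the pattern (illustrative only; op.dimension() stands in for any of the accessors touched below):

    // Before the bump: the generated accessor returned llvm::APInt,
    // so callers had to extract the raw integer explicitly.
    int64_t dim = op.dimension().getSExtValue();

    // After the bump: the accessor returns the integer value directly.
    int64_t dim = op.dimension();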

@@ -1,2 +1,2 @@
-5ffd940ac02a8e000691c45a6dc4f69d0198e675
+202766947edb5407b84484185608aac077085608

@@ -172,7 +172,7 @@ static LogicalResult Verify(DotGeneralOp op) {
 /// Fold get_dimension_size when the said shape dimension is a constant.
 OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
   RankedTensorType type = operand().getType().cast<RankedTensorType>();
-  int32_t dim = dimension().getSExtValue();
+  int32_t dim = dimension();
   if (type.isDynamic(dim)) return {};
   // The result type is always is a 0-d i32 tensor.
   return DenseIntElementsAttr::get<int32_t>(
@@ -190,7 +190,7 @@ static LogicalResult Verify(IotaOp op) {
   if (shape.getRank() == 0)
     return op.emitOpError() << "does not support scalars.";

-  auto iota_dimension = op.iota_dimension().getSExtValue();
+  auto iota_dimension = op.iota_dimension();
   if (iota_dimension >= shape.getRank() || iota_dimension < 0)
     return op.emitOpError() << "iota dimension cannot go beyond the output "
                                "rank or be negative.";
@@ -212,8 +212,7 @@ struct IotaBroadcast : public OpRewritePattern<IotaOp> {
     auto iota_dimension = iota.iota_dimension();

     auto iota_type = RankedTensorType::get(
-        {result_ty.getDimSize(iota_dimension.getLimitedValue())},
-        result_ty.getElementType());
+        {result_ty.getDimSize(iota_dimension)}, result_ty.getElementType());

     auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
                                             rewriter.getI64IntegerAttr(0));
@@ -233,7 +232,7 @@ void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
 }

 OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
-  auto dimension = iota_dimension().getLimitedValue();
+  auto dimension = iota_dimension();
   auto result_ty = getResult().getType().cast<ShapedType>();
   if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
     Builder builder(getContext());
@@ -277,7 +276,7 @@ struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
     }

     auto iota_dimension = iota.iota_dimension();
-    auto iota_dimension_int = iota_dimension.getLimitedValue();
+    auto iota_dimension_int = iota_dimension;

     auto converted_shape = rewriter.create<IndexCastOp>(
         iota.getLoc(),
@@ -476,7 +475,7 @@ static LogicalResult Verify(DequantizeOp op) {
 //===----------------------------------------------------------------------===//

 static LogicalResult Verify(GetTupleElementOp op) {
-  auto indexVal = op.index().getZExtValue();
+  auto indexVal = op.index();
   auto operandType = op.getOperand().getType().cast<TupleType>();
   if (indexVal >= operandType.size()) {
     return op.emitOpError(
@@ -495,7 +494,7 @@ static LogicalResult Verify(GetTupleElementOp op) {

 OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
   if (auto tupleOp =
           dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
-    return tupleOp.getOperand(index().getLimitedValue());
+    return tupleOp.getOperand(index());
   }
   return {};
@@ -565,8 +564,8 @@ static LogicalResult Verify(AllToAllOp op) {
   // count.
   auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
   if (!type) return success();
-  auto split_dim_size = type.getDimSize(op.split_dimension().getSExtValue());
-  auto split_count = op.split_count().getSExtValue();
+  auto split_dim_size = type.getDimSize(op.split_dimension());
+  auto split_count = op.split_count();
   if (split_dim_size % split_count != 0) {
     return op.emitError() << "split dimension has size " << split_dim_size
                           << ", expected to be a multiple of split_count "
@@ -914,7 +913,7 @@ class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConcatenateOp op,
                                 PatternRewriter& rewriter) const override {
-    auto axis = op.dimension().getLimitedValue();
+    auto axis = op.dimension();
     llvm::SmallVector<Value, 6> new_operands;
     for (auto operand : op.getOperands()) {
       auto ty = operand.getType().cast<ShapedType>();
@@ -994,7 +993,7 @@ void ConcatenateOp::getCanonicalizationPatterns(
 template <typename T>
 static Attribute foldConcatenateHelper(ConcatenateOp* op,
                                        ArrayRef<Attribute> operands) {
-  auto axis = op->dimension().getLimitedValue();
+  auto axis = op->dimension();
   auto type = op->getType().cast<ShapedType>();

   SmallVector<T, 6> values;
@@ -1042,7 +1041,7 @@ OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
   ShapedType type = getResult().getType().cast<ShapedType>();
   if (!type.hasStaticShape()) return {};

-  auto axis = dimension().getLimitedValue();
+  auto axis = dimension();
   if (auto attr = foldConcatenate(this, operands)) {
     return attr;
   }
@@ -1845,7 +1844,7 @@ struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
       return failure();
     }

-    auto dimension = concat.dimension().getSExtValue();
+    auto dimension = concat.dimension();

     auto start = slice.start_indices().getIntValues();
     auto limit = slice.limit_indices().getIntValues();
@@ -1995,7 +1994,7 @@ static LogicalResult Verify(SortOp op) {
     return op.emitOpError("requires all inputs to have the same dimensions");

   int64_t rank = input_shape.size();
-  int64_t cmp_dim = op.dimension().getSExtValue();
+  int64_t cmp_dim = op.dimension();
   if (cmp_dim < -rank || cmp_dim >= rank)
     return op.emitOpError("dimension attribute value must be in range [-")
            << rank << ", " << rank << "), but found " << cmp_dim;

@@ -704,7 +704,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
         [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
             ValueRange args) {
           Value castOp = nestedBuilder.create<IndexCastOp>(
-              nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
+              nestedLoc, ivs[iotaOp.iota_dimension()],
               nestedBuilder.getIntegerType(
                   resultElementType.getIntOrFloatBitWidth()));
           if (resultElementType.template isa<FloatType>()) {

@@ -117,7 +117,7 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
                                 PatternRewriter &rewriter) const override {
     auto output_type = op.getType().cast<ShapedType>();
     auto output_size = output_type.getNumElements();
-    auto dimension = op.iota_dimension().getSExtValue();
+    auto dimension = op.iota_dimension();
     auto max_dim_size = output_type.getDimSize(dimension);
     auto element_type = output_type.getElementType();

@@ -80,7 +80,7 @@ void MatchAndRewrite(WhileOp whileOp) {
     // the external value is captured.
     if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
       if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
-      int index = gte.index().getSExtValue();
+      int index = gte.index();
       return {tupleOp.getOperand(index), index};
     }
     return {nullptr, 0};
@@ -154,7 +154,7 @@ void MatchAndRewrite(WhileOp whileOp) {
       use->erase();
       continue;
     }
-    int index = gte.index().getSExtValue();
+    int index = gte.index();
     // If after the loop induction variable, then decrement as we don't include
     // the loop induction variable in the for iter operands.
     if (index > loopIndVar.second) --index;

@@ -122,7 +122,7 @@ class UnfuseBatchNormInferencePattern
     if (!fp_type) {
       return failure();
     }
-    int64_t feature_dim = bn_op.feature_index().getSExtValue();
+    int64_t feature_dim = bn_op.feature_index();

     // Add epsilon to the variance and sqrt to get stddev:
     // stddev = sqrt(variance + epsilon)