[mlir][hlo] Avoid dyn_cast_or_null when called with getDefiningOp result (NFC)

PiperOrigin-RevId: 376110457
This commit is contained in:
Adrian Kuegel 2021-05-27 00:18:35 -07:00 committed by TensorFlow MLIR Team
parent d939a156d8
commit a4fa6afa07
2 changed files with 13 additions and 18 deletions

View File

@ -639,9 +639,8 @@ static LogicalResult Verify(GetTupleElementOp op) {
}
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
if (auto tupleOp =
dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
return tupleOp.getOperand(index());
if (auto tuple_op = getOperand().getDefiningOp<mhlo::TupleOp>()) {
return tuple_op.getOperand(index());
}
return {};
@ -679,8 +678,7 @@ struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
if (op.val().empty()) return failure();
Value first_element = op.val().front();
auto first_element_op =
dyn_cast_or_null<GetTupleElementOp>(first_element.getDefiningOp());
auto first_element_op = first_element.getDefiningOp<GetTupleElementOp>();
if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
return failure();
@ -688,8 +686,8 @@ struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
if (tuple_predecessor.getType() != op.getType()) return failure();
for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
auto element_op = dyn_cast_or_null<GetTupleElementOp>(
element_and_idx.value().getDefiningOp());
auto element_op =
element_and_idx.value().getDefiningOp<GetTupleElementOp>();
if (!element_op ||
element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
element_op.getOperand() != tuple_predecessor)
@ -1060,8 +1058,8 @@ LogicalResult ComplexOp::inferReturnTypes(
}
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp());
auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp());
auto real_op = getOperand(0).getDefiningOp<mhlo::RealOp>();
auto imag_op = getOperand(1).getDefiningOp<mhlo::ImagOp>();
if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
return real_op.getOperand();
}
@ -1098,8 +1096,7 @@ LogicalResult ImagOp::inferReturnTypes(
}
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op =
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
return complex_op.getOperand(1);
}
@ -1141,8 +1138,7 @@ LogicalResult RealOp::inferReturnTypes(
}
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op =
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
return complex_op.getOperand(0);
}
@ -2378,8 +2374,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
return getOperand();
}
if (auto prev_op =
dyn_cast_or_null<ReshapeOp>(getOperand().getDefiningOp())) {
if (auto prev_op = getOperand().getDefiningOp<ReshapeOp>()) {
setOperand(prev_op.getOperand());
return getResult();
}
@ -2954,7 +2949,7 @@ struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
auto slice_input = slice.operand();
auto slice_input_ty = slice_input.getType().cast<ShapedType>();
auto concat = dyn_cast_or_null<ConcatenateOp>(slice_input.getDefiningOp());
auto concat = slice_input.getDefiningOp<ConcatenateOp>();
if (!concat) {
return failure();
}

View File

@ -70,8 +70,8 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1) return failure();
auto defining_op = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(
op->getOperand(0).getDefiningOp());
auto defining_op =
op->getOperand(0).getDefiningOp<InferShapedTypeOpInterface>();
if (!defining_op) return failure();
SmallVector<Value, 4> return_shapes;
if (failed(defining_op.reifyReturnTypeShapes(