From a4fa6afa0738569e69fdf9f205888f0a23f3fe09 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 27 May 2021 00:18:35 -0700 Subject: [PATCH] [mlir][hlo] Avoid dyn_cast_or_null when called with getDefiningOp result (NFC) PiperOrigin-RevId: 376110457 --- lib/Dialect/mhlo/IR/hlo_ops.cc | 27 ++++++++----------- .../transforms/test_infer_shaped_type_pass.cc | 4 +-- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index eb2bb35..f91bf4b 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -639,9 +639,8 @@ static LogicalResult Verify(GetTupleElementOp op) { } OpFoldResult GetTupleElementOp::fold(ArrayRef operands) { - if (auto tupleOp = - dyn_cast_or_null(getOperand().getDefiningOp())) { - return tupleOp.getOperand(index()); + if (auto tuple_op = getOperand().getDefiningOp()) { + return tuple_op.getOperand(index()); } return {}; @@ -679,8 +678,7 @@ struct UnpackRepackSameTuple : public OpRewritePattern { if (op.val().empty()) return failure(); Value first_element = op.val().front(); - auto first_element_op = - dyn_cast_or_null(first_element.getDefiningOp()); + auto first_element_op = first_element.getDefiningOp(); if (!first_element_op || first_element_op.indexAttr().getInt() != 0) return failure(); @@ -688,8 +686,8 @@ struct UnpackRepackSameTuple : public OpRewritePattern { if (tuple_predecessor.getType() != op.getType()) return failure(); for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) { - auto element_op = dyn_cast_or_null( - element_and_idx.value().getDefiningOp()); + auto element_op = + element_and_idx.value().getDefiningOp(); if (!element_op || element_op.indexAttr().getInt() != element_and_idx.index() + 1 || element_op.getOperand() != tuple_predecessor) @@ -1060,8 +1058,8 @@ LogicalResult ComplexOp::inferReturnTypes( } OpFoldResult ComplexOp::fold(ArrayRef operands) { - auto real_op = dyn_cast_or_null(getOperand(0).getDefiningOp()); - auto imag_op = dyn_cast_or_null(getOperand(1).getDefiningOp()); + auto real_op = getOperand(0).getDefiningOp(); + auto imag_op = getOperand(1).getDefiningOp(); if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) { return real_op.getOperand(); } @@ -1098,8 +1096,7 @@ LogicalResult ImagOp::inferReturnTypes( } OpFoldResult ImagOp::fold(ArrayRef operands) { - if (auto complex_op = - dyn_cast_or_null(getOperand().getDefiningOp())) { + if (auto complex_op = getOperand().getDefiningOp()) { return complex_op.getOperand(1); } @@ -1141,8 +1138,7 @@ LogicalResult RealOp::inferReturnTypes( } OpFoldResult RealOp::fold(ArrayRef operands) { - if (auto complex_op = - dyn_cast_or_null(getOperand().getDefiningOp())) { + if (auto complex_op = getOperand().getDefiningOp()) { return complex_op.getOperand(0); } @@ -2378,8 +2374,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { return getOperand(); } - if (auto prev_op = - dyn_cast_or_null(getOperand().getDefiningOp())) { + if (auto prev_op = getOperand().getDefiningOp()) { setOperand(prev_op.getOperand()); return getResult(); } @@ -2954,7 +2949,7 @@ struct SimplifyConcatSlice : public OpRewritePattern { auto slice_input = slice.operand(); auto slice_input_ty = slice_input.getType().cast(); - auto concat = dyn_cast_or_null(slice_input.getDefiningOp()); + auto concat = slice_input.getDefiningOp(); if (!concat) { return failure(); } diff --git a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc index d9588b3..251cc8b 100644 --- a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc +++ b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc @@ -70,8 +70,8 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (op->getNumOperands() != 1) return failure(); - auto defining_op = llvm::dyn_cast_or_null( - op->getOperand(0).getDefiningOp()); + auto defining_op = + op->getOperand(0).getDefiningOp(); if (!defining_op) return failure(); SmallVector return_shapes; if (failed(defining_op.reifyReturnTypeShapes(