[mlir][hlo] Avoid dyn_cast_or_null when called with getDefiningOp result (NFC)

PiperOrigin-RevId: 376110457
This commit is contained in:
Adrian Kuegel 2021-05-27 00:18:35 -07:00 committed by TensorFlow MLIR Team
parent d939a156d8
commit a4fa6afa07
2 changed files with 13 additions and 18 deletions

View File

@ -639,9 +639,8 @@ static LogicalResult Verify(GetTupleElementOp op) {
} }
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) { OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
if (auto tupleOp = if (auto tuple_op = getOperand().getDefiningOp<mhlo::TupleOp>()) {
dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) { return tuple_op.getOperand(index());
return tupleOp.getOperand(index());
} }
return {}; return {};
@ -679,8 +678,7 @@ struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
if (op.val().empty()) return failure(); if (op.val().empty()) return failure();
Value first_element = op.val().front(); Value first_element = op.val().front();
auto first_element_op = auto first_element_op = first_element.getDefiningOp<GetTupleElementOp>();
dyn_cast_or_null<GetTupleElementOp>(first_element.getDefiningOp());
if (!first_element_op || first_element_op.indexAttr().getInt() != 0) if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
return failure(); return failure();
@ -688,8 +686,8 @@ struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
if (tuple_predecessor.getType() != op.getType()) return failure(); if (tuple_predecessor.getType() != op.getType()) return failure();
for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) { for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
auto element_op = dyn_cast_or_null<GetTupleElementOp>( auto element_op =
element_and_idx.value().getDefiningOp()); element_and_idx.value().getDefiningOp<GetTupleElementOp>();
if (!element_op || if (!element_op ||
element_op.indexAttr().getInt() != element_and_idx.index() + 1 || element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
element_op.getOperand() != tuple_predecessor) element_op.getOperand() != tuple_predecessor)
@ -1060,8 +1058,8 @@ LogicalResult ComplexOp::inferReturnTypes(
} }
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) { OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp()); auto real_op = getOperand(0).getDefiningOp<mhlo::RealOp>();
auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp()); auto imag_op = getOperand(1).getDefiningOp<mhlo::ImagOp>();
if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) { if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
return real_op.getOperand(); return real_op.getOperand();
} }
@ -1098,8 +1096,7 @@ LogicalResult ImagOp::inferReturnTypes(
} }
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) { OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op = if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
return complex_op.getOperand(1); return complex_op.getOperand(1);
} }
@ -1141,8 +1138,7 @@ LogicalResult RealOp::inferReturnTypes(
} }
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) { OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op = if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
return complex_op.getOperand(0); return complex_op.getOperand(0);
} }
@ -2378,8 +2374,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
return getOperand(); return getOperand();
} }
if (auto prev_op = if (auto prev_op = getOperand().getDefiningOp<ReshapeOp>()) {
dyn_cast_or_null<ReshapeOp>(getOperand().getDefiningOp())) {
setOperand(prev_op.getOperand()); setOperand(prev_op.getOperand());
return getResult(); return getResult();
} }
@ -2954,7 +2949,7 @@ struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
auto slice_input = slice.operand(); auto slice_input = slice.operand();
auto slice_input_ty = slice_input.getType().cast<ShapedType>(); auto slice_input_ty = slice_input.getType().cast<ShapedType>();
auto concat = dyn_cast_or_null<ConcatenateOp>(slice_input.getDefiningOp()); auto concat = slice_input.getDefiningOp<ConcatenateOp>();
if (!concat) { if (!concat) {
return failure(); return failure();
} }

View File

@ -70,8 +70,8 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op, LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1) return failure(); if (op->getNumOperands() != 1) return failure();
auto defining_op = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>( auto defining_op =
op->getOperand(0).getDefiningOp()); op->getOperand(0).getDefiningOp<InferShapedTypeOpInterface>();
if (!defining_op) return failure(); if (!defining_op) return failure();
SmallVector<Value, 4> return_shapes; SmallVector<Value, 4> return_shapes;
if (failed(defining_op.reifyReturnTypeShapes( if (failed(defining_op.reifyReturnTypeShapes(