[mlir][hlo] Avoid dyn_cast_or_null when called with getDefiningOp result (NFC)
PiperOrigin-RevId: 376110457
This commit is contained in:
		
							parent
							
								
									d939a156d8
								
							
						
					
					
						commit
						a4fa6afa07
					
				|  | @ -639,9 +639,8 @@ static LogicalResult Verify(GetTupleElementOp op) { | |||
| } | ||||
| 
 | ||||
| OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) { | ||||
|   if (auto tupleOp = | ||||
|           dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) { | ||||
|     return tupleOp.getOperand(index()); | ||||
|   if (auto tuple_op = getOperand().getDefiningOp<mhlo::TupleOp>()) { | ||||
|     return tuple_op.getOperand(index()); | ||||
|   } | ||||
| 
 | ||||
|   return {}; | ||||
|  | @ -679,8 +678,7 @@ struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> { | |||
|     if (op.val().empty()) return failure(); | ||||
| 
 | ||||
|     Value first_element = op.val().front(); | ||||
|     auto first_element_op = | ||||
|         dyn_cast_or_null<GetTupleElementOp>(first_element.getDefiningOp()); | ||||
|     auto first_element_op = first_element.getDefiningOp<GetTupleElementOp>(); | ||||
|     if (!first_element_op || first_element_op.indexAttr().getInt() != 0) | ||||
|       return failure(); | ||||
| 
 | ||||
|  | @ -688,8 +686,8 @@ struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> { | |||
|     if (tuple_predecessor.getType() != op.getType()) return failure(); | ||||
| 
 | ||||
|     for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) { | ||||
|       auto element_op = dyn_cast_or_null<GetTupleElementOp>( | ||||
|           element_and_idx.value().getDefiningOp()); | ||||
|       auto element_op = | ||||
|           element_and_idx.value().getDefiningOp<GetTupleElementOp>(); | ||||
|       if (!element_op || | ||||
|           element_op.indexAttr().getInt() != element_and_idx.index() + 1 || | ||||
|           element_op.getOperand() != tuple_predecessor) | ||||
|  | @ -1060,8 +1058,8 @@ LogicalResult ComplexOp::inferReturnTypes( | |||
| } | ||||
| 
 | ||||
| OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) { | ||||
|   auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp()); | ||||
|   auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp()); | ||||
|   auto real_op = getOperand(0).getDefiningOp<mhlo::RealOp>(); | ||||
|   auto imag_op = getOperand(1).getDefiningOp<mhlo::ImagOp>(); | ||||
|   if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) { | ||||
|     return real_op.getOperand(); | ||||
|   } | ||||
|  | @ -1098,8 +1096,7 @@ LogicalResult ImagOp::inferReturnTypes( | |||
| } | ||||
| 
 | ||||
| OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) { | ||||
|   if (auto complex_op = | ||||
|           dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) { | ||||
|   if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) { | ||||
|     return complex_op.getOperand(1); | ||||
|   } | ||||
| 
 | ||||
|  | @ -1141,8 +1138,7 @@ LogicalResult RealOp::inferReturnTypes( | |||
| } | ||||
| 
 | ||||
| OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) { | ||||
|   if (auto complex_op = | ||||
|           dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) { | ||||
|   if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) { | ||||
|     return complex_op.getOperand(0); | ||||
|   } | ||||
| 
 | ||||
|  | @ -2378,8 +2374,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) { | |||
|     return getOperand(); | ||||
|   } | ||||
| 
 | ||||
|   if (auto prev_op = | ||||
|           dyn_cast_or_null<ReshapeOp>(getOperand().getDefiningOp())) { | ||||
|   if (auto prev_op = getOperand().getDefiningOp<ReshapeOp>()) { | ||||
|     setOperand(prev_op.getOperand()); | ||||
|     return getResult(); | ||||
|   } | ||||
|  | @ -2954,7 +2949,7 @@ struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> { | |||
| 
 | ||||
|     auto slice_input = slice.operand(); | ||||
|     auto slice_input_ty = slice_input.getType().cast<ShapedType>(); | ||||
|     auto concat = dyn_cast_or_null<ConcatenateOp>(slice_input.getDefiningOp()); | ||||
|     auto concat = slice_input.getDefiningOp<ConcatenateOp>(); | ||||
|     if (!concat) { | ||||
|       return failure(); | ||||
|     } | ||||
|  |  | |||
|  | @ -70,8 +70,8 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern { | |||
|   LogicalResult matchAndRewrite(Operation *op, | ||||
|                                 PatternRewriter &rewriter) const override { | ||||
|     if (op->getNumOperands() != 1) return failure(); | ||||
|     auto defining_op = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>( | ||||
|         op->getOperand(0).getDefiningOp()); | ||||
|     auto defining_op = | ||||
|         op->getOperand(0).getDefiningOp<InferShapedTypeOpInterface>(); | ||||
|     if (!defining_op) return failure(); | ||||
|     SmallVector<Value, 4> return_shapes; | ||||
|     if (failed(defining_op.reifyReturnTypeShapes( | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue