[mlir][hlo] Avoid dyn_cast_or_null when called with getDefiningOp result (NFC)
PiperOrigin-RevId: 376110457
This commit is contained in:
parent
d939a156d8
commit
a4fa6afa07
|
@ -639,9 +639,8 @@ static LogicalResult Verify(GetTupleElementOp op) {
|
|||
}
|
||||
|
||||
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto tupleOp =
|
||||
dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
|
||||
return tupleOp.getOperand(index());
|
||||
if (auto tuple_op = getOperand().getDefiningOp<mhlo::TupleOp>()) {
|
||||
return tuple_op.getOperand(index());
|
||||
}
|
||||
|
||||
return {};
|
||||
|
@ -679,8 +678,7 @@ struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
|
|||
if (op.val().empty()) return failure();
|
||||
|
||||
Value first_element = op.val().front();
|
||||
auto first_element_op =
|
||||
dyn_cast_or_null<GetTupleElementOp>(first_element.getDefiningOp());
|
||||
auto first_element_op = first_element.getDefiningOp<GetTupleElementOp>();
|
||||
if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
|
||||
return failure();
|
||||
|
||||
|
@ -688,8 +686,8 @@ struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
|
|||
if (tuple_predecessor.getType() != op.getType()) return failure();
|
||||
|
||||
for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
|
||||
auto element_op = dyn_cast_or_null<GetTupleElementOp>(
|
||||
element_and_idx.value().getDefiningOp());
|
||||
auto element_op =
|
||||
element_and_idx.value().getDefiningOp<GetTupleElementOp>();
|
||||
if (!element_op ||
|
||||
element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
|
||||
element_op.getOperand() != tuple_predecessor)
|
||||
|
@ -1060,8 +1058,8 @@ LogicalResult ComplexOp::inferReturnTypes(
|
|||
}
|
||||
|
||||
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp());
|
||||
auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp());
|
||||
auto real_op = getOperand(0).getDefiningOp<mhlo::RealOp>();
|
||||
auto imag_op = getOperand(1).getDefiningOp<mhlo::ImagOp>();
|
||||
if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
|
||||
return real_op.getOperand();
|
||||
}
|
||||
|
@ -1098,8 +1096,7 @@ LogicalResult ImagOp::inferReturnTypes(
|
|||
}
|
||||
|
||||
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto complex_op =
|
||||
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
|
||||
if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
|
||||
return complex_op.getOperand(1);
|
||||
}
|
||||
|
||||
|
@ -1141,8 +1138,7 @@ LogicalResult RealOp::inferReturnTypes(
|
|||
}
|
||||
|
||||
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto complex_op =
|
||||
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
|
||||
if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
|
||||
return complex_op.getOperand(0);
|
||||
}
|
||||
|
||||
|
@ -2378,8 +2374,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
|||
return getOperand();
|
||||
}
|
||||
|
||||
if (auto prev_op =
|
||||
dyn_cast_or_null<ReshapeOp>(getOperand().getDefiningOp())) {
|
||||
if (auto prev_op = getOperand().getDefiningOp<ReshapeOp>()) {
|
||||
setOperand(prev_op.getOperand());
|
||||
return getResult();
|
||||
}
|
||||
|
@ -2954,7 +2949,7 @@ struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
|
|||
|
||||
auto slice_input = slice.operand();
|
||||
auto slice_input_ty = slice_input.getType().cast<ShapedType>();
|
||||
auto concat = dyn_cast_or_null<ConcatenateOp>(slice_input.getDefiningOp());
|
||||
auto concat = slice_input.getDefiningOp<ConcatenateOp>();
|
||||
if (!concat) {
|
||||
return failure();
|
||||
}
|
||||
|
|
|
@ -70,8 +70,8 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
|
|||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (op->getNumOperands() != 1) return failure();
|
||||
auto defining_op = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(
|
||||
op->getOperand(0).getDefiningOp());
|
||||
auto defining_op =
|
||||
op->getOperand(0).getDefiningOp<InferShapedTypeOpInterface>();
|
||||
if (!defining_op) return failure();
|
||||
SmallVector<Value, 4> return_shapes;
|
||||
if (failed(defining_op.reifyReturnTypeShapes(
|
||||
|
|
Loading…
Reference in New Issue