Add canonicalization for unpacking and repacking the same tuple (e.g. tuple -> get_tuple_element -> tuple).
These unpacking and repacking of tuples may be generated when modifying tuple arguments or results. PiperOrigin-RevId: 325162694
This commit is contained in:
parent
ad12e06ceb
commit
c340367702
|
@ -671,6 +671,7 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
|
|||
"OpBuilder &builder, OperationState &results, "
|
||||
"ValueRange values">];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def HLO_CompareOp: HLO_Op<"compare",
|
||||
|
|
|
@ -506,6 +506,46 @@ static LogicalResult Verify(TupleOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Pattern for unpacking and repacking the same tuple.
|
||||
struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
|
||||
using OpRewritePattern<TupleOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TupleOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
if (op.val().empty()) return failure();
|
||||
|
||||
Value first_element = op.val().front();
|
||||
auto first_element_op =
|
||||
dyn_cast_or_null<GetTupleElementOp>(first_element.getDefiningOp());
|
||||
if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
|
||||
return failure();
|
||||
|
||||
Value tuple_predecessor = first_element_op.getOperand();
|
||||
if (tuple_predecessor.getType() != op.getType()) return failure();
|
||||
|
||||
for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
|
||||
auto element_op = dyn_cast_or_null<GetTupleElementOp>(
|
||||
element_and_idx.value().getDefiningOp());
|
||||
if (!element_op ||
|
||||
element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
|
||||
element_op.getOperand() != tuple_predecessor)
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, tuple_predecessor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||
MLIRContext* context) {
|
||||
results.insert<UnpackRepackSameTuple>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllToAllOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -561,3 +561,25 @@ func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
|
|||
|
||||
return %arg0 : tensor<i64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: unpack_repack_same_tuple
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tuple<tensor<i32>, !mhlo.token, tensor<f32>>)
|
||||
func @unpack_repack_same_tuple(%arg0: tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tuple<tensor<i32>, !mhlo.token, tensor<f32>> {
|
||||
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tensor<i32>
|
||||
%1 = "mhlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> !mhlo.token
|
||||
%2 = "mhlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tensor<f32>
|
||||
%3 = "mhlo.tuple"(%0, %1, %2) : (tensor<i32>, !mhlo.token, tensor<f32>) -> tuple<tensor<i32>, !mhlo.token, tensor<f32>>
|
||||
|
||||
// CHECK: return [[ARG0]]
|
||||
return %3 : tuple<tensor<i32>, !mhlo.token, tensor<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: unpack_repack_same_tuple_single_element
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tuple<tensor<i32>>)
|
||||
func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tuple<tensor<i32>> {
|
||||
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
|
||||
%3 = "mhlo.tuple"(%0) : (tensor<i32>) -> tuple<tensor<i32>>
|
||||
|
||||
// CHECK: return [[ARG0]]
|
||||
return %3 : tuple<tensor<i32>>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue