Add canonicalization for unpacking and repacking the same tuple (e.g. tuple -> get_tuple_element -> tuple).
These unpacking and repacking of tuples may be generated when modifying tuple arguments or results. PiperOrigin-RevId: 325162694
This commit is contained in:
parent
ad12e06ceb
commit
c340367702
|
@ -671,6 +671,7 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
|
||||||
"OpBuilder &builder, OperationState &results, "
|
"OpBuilder &builder, OperationState &results, "
|
||||||
"ValueRange values">];
|
"ValueRange values">];
|
||||||
|
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_CompareOp: HLO_Op<"compare",
|
def HLO_CompareOp: HLO_Op<"compare",
|
||||||
|
|
|
@ -506,6 +506,46 @@ static LogicalResult Verify(TupleOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Pattern for unpacking and repacking the same tuple.
|
||||||
|
struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
|
||||||
|
using OpRewritePattern<TupleOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(TupleOp op,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
if (op.val().empty()) return failure();
|
||||||
|
|
||||||
|
Value first_element = op.val().front();
|
||||||
|
auto first_element_op =
|
||||||
|
dyn_cast_or_null<GetTupleElementOp>(first_element.getDefiningOp());
|
||||||
|
if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value tuple_predecessor = first_element_op.getOperand();
|
||||||
|
if (tuple_predecessor.getType() != op.getType()) return failure();
|
||||||
|
|
||||||
|
for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
|
||||||
|
auto element_op = dyn_cast_or_null<GetTupleElementOp>(
|
||||||
|
element_and_idx.value().getDefiningOp());
|
||||||
|
if (!element_op ||
|
||||||
|
element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
|
||||||
|
element_op.getOperand() != tuple_predecessor)
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, tuple_predecessor);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||||
|
MLIRContext* context) {
|
||||||
|
results.insert<UnpackRepackSameTuple>(context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AllToAllOp
|
// AllToAllOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -561,3 +561,25 @@ func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
|
||||||
|
|
||||||
return %arg0 : tensor<i64>
|
return %arg0 : tensor<i64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: unpack_repack_same_tuple
|
||||||
|
// CHECK-SAME: ([[ARG0:%.*]]: tuple<tensor<i32>, !mhlo.token, tensor<f32>>)
|
||||||
|
func @unpack_repack_same_tuple(%arg0: tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tuple<tensor<i32>, !mhlo.token, tensor<f32>> {
|
||||||
|
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tensor<i32>
|
||||||
|
%1 = "mhlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> !mhlo.token
|
||||||
|
%2 = "mhlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tensor<f32>
|
||||||
|
%3 = "mhlo.tuple"(%0, %1, %2) : (tensor<i32>, !mhlo.token, tensor<f32>) -> tuple<tensor<i32>, !mhlo.token, tensor<f32>>
|
||||||
|
|
||||||
|
// CHECK: return [[ARG0]]
|
||||||
|
return %3 : tuple<tensor<i32>, !mhlo.token, tensor<f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: unpack_repack_same_tuple_single_element
|
||||||
|
// CHECK-SAME: ([[ARG0:%.*]]: tuple<tensor<i32>>)
|
||||||
|
func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tuple<tensor<i32>> {
|
||||||
|
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
|
||||||
|
%3 = "mhlo.tuple"(%0) : (tensor<i32>) -> tuple<tensor<i32>>
|
||||||
|
|
||||||
|
// CHECK: return [[ARG0]]
|
||||||
|
return %3 : tuple<tensor<i32>>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue