Add canonicalization for unpacking and repacking the same tuple (e.g. tuple -> get_tuple_element -> tuple).

These unpacking and repacking of tuples may be generated when modifying tuple arguments or results.

PiperOrigin-RevId: 325162694
This commit is contained in:
Andy Ly 2020-08-05 21:36:20 -07:00 committed by TensorFlow MLIR Team
parent ad12e06ceb
commit c340367702
3 changed files with 63 additions and 0 deletions

View File

@ -671,6 +671,7 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
"OpBuilder &builder, OperationState &results, "
"ValueRange values">];
let hasCanonicalizer = 1;
}
def HLO_CompareOp: HLO_Op<"compare",

View File

@ -506,6 +506,46 @@ static LogicalResult Verify(TupleOp op) {
return success();
}
namespace {
// Pattern for unpacking and repacking the same tuple.
struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
using OpRewritePattern<TupleOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TupleOp op,
PatternRewriter& rewriter) const override {
if (op.val().empty()) return failure();
Value first_element = op.val().front();
auto first_element_op =
dyn_cast_or_null<GetTupleElementOp>(first_element.getDefiningOp());
if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
return failure();
Value tuple_predecessor = first_element_op.getOperand();
if (tuple_predecessor.getType() != op.getType()) return failure();
for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
auto element_op = dyn_cast_or_null<GetTupleElementOp>(
element_and_idx.value().getDefiningOp());
if (!element_op ||
element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
element_op.getOperand() != tuple_predecessor)
return failure();
}
rewriter.replaceOp(op, tuple_predecessor);
return success();
}
};
} // namespace
void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<UnpackRepackSameTuple>(context);
}
//===----------------------------------------------------------------------===//
// AllToAllOp
//===----------------------------------------------------------------------===//

View File

@ -561,3 +561,25 @@ func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
return %arg0 : tensor<i64>
}
// CHECK-LABEL: unpack_repack_same_tuple
// CHECK-SAME: ([[ARG0:%.*]]: tuple<tensor<i32>, !mhlo.token, tensor<f32>>)
func @unpack_repack_same_tuple(%arg0: tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tuple<tensor<i32>, !mhlo.token, tensor<f32>> {
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tensor<i32>
%1 = "mhlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> !mhlo.token
%2 = "mhlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tensor<f32>
%3 = "mhlo.tuple"(%0, %1, %2) : (tensor<i32>, !mhlo.token, tensor<f32>) -> tuple<tensor<i32>, !mhlo.token, tensor<f32>>
// CHECK: return [[ARG0]]
return %3 : tuple<tensor<i32>, !mhlo.token, tensor<f32>>
}
// CHECK-LABEL: unpack_repack_same_tuple_single_element
// CHECK-SAME: ([[ARG0:%.*]]: tuple<tensor<i32>>)
func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tuple<tensor<i32>> {
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
%3 = "mhlo.tuple"(%0) : (tensor<i32>) -> tuple<tensor<i32>>
// CHECK: return [[ARG0]]
return %3 : tuple<tensor<i32>>
}