From c340367702857f50d88abd1ee749830935863166 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Wed, 5 Aug 2020 21:36:20 -0700 Subject: [PATCH] Add canonicalization for unpacking and repacking the same tuple (e.g. tuple -> get_tuple_element -> tuple). These unpacking and repacking of tuples may be generated when modifying tuple arguments or results. PiperOrigin-RevId: 325162694 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 1 + lib/Dialect/mhlo/IR/hlo_ops.cc | 40 +++++++++++++++++++++ tests/canonicalize.mlir | 22 ++++++++++++ 3 files changed, 63 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index e83bf87..4c09c20 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -671,6 +671,7 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { "OpBuilder &builder, OperationState &results, " "ValueRange values">]; + let hasCanonicalizer = 1; } def HLO_CompareOp: HLO_Op<"compare", diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 69b0100..de3f950 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -506,6 +506,46 @@ static LogicalResult Verify(TupleOp op) { return success(); } +namespace { + +// Pattern for unpacking and repacking the same tuple. +struct UnpackRepackSameTuple : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TupleOp op, + PatternRewriter& rewriter) const override { + if (op.val().empty()) return failure(); + + Value first_element = op.val().front(); + auto first_element_op = + dyn_cast_or_null(first_element.getDefiningOp()); + if (!first_element_op || first_element_op.indexAttr().getInt() != 0) + return failure(); + + Value tuple_predecessor = first_element_op.getOperand(); + if (tuple_predecessor.getType() != op.getType()) return failure(); + + for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) { + auto element_op = dyn_cast_or_null( + element_and_idx.value().getDefiningOp()); + if (!element_op || + element_op.indexAttr().getInt() != element_and_idx.index() + 1 || + element_op.getOperand() != tuple_predecessor) + return failure(); + } + + rewriter.replaceOp(op, tuple_predecessor); + return success(); + } +}; + +} // namespace + +void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // AllToAllOp //===----------------------------------------------------------------------===// diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index f0fe522..e793e21 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -561,3 +561,25 @@ func @dce_while_without_side_effect(%arg0: tensor) -> tensor { return %arg0 : tensor } + +// CHECK-LABEL: unpack_repack_same_tuple +// CHECK-SAME: ([[ARG0:%.*]]: tuple, !mhlo.token, tensor>) +func @unpack_repack_same_tuple(%arg0: tuple, !mhlo.token, tensor>) -> tuple, !mhlo.token, tensor> { + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, !mhlo.token, tensor>) -> tensor + %1 = "mhlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, !mhlo.token, tensor>) -> !mhlo.token + %2 = "mhlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, !mhlo.token, tensor>) -> tensor + %3 = "mhlo.tuple"(%0, %1, %2) : (tensor, !mhlo.token, tensor) -> tuple, !mhlo.token, tensor> + + // CHECK: return [[ARG0]] + return %3 : tuple, !mhlo.token, tensor> +} + +// CHECK-LABEL: unpack_repack_same_tuple_single_element +// CHECK-SAME: ([[ARG0:%.*]]: tuple>) +func @unpack_repack_same_tuple_single_element(%arg0: tuple>) -> tuple> { + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple>) -> tensor + %3 = "mhlo.tuple"(%0) : (tensor) -> tuple> + + // CHECK: return [[ARG0]] + return %3 : tuple> +}