diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 998fdd6..b00973b 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1277,13 +1277,54 @@ class DynamicReshapeOpNotActuallyDynamic return success(); } }; + +// Canonicalizes +// %0 = "mhlo.dynamic_reshape"(%tensor, %shape) +// %1 = same_operands_and_result_shape_op(%tensor) +// %2 = "mhlo.dynamic_reshape"(%1, %shape) +// ... uses of %2. +// +// into +// +// %0 = "mhlo.dynamic_reshape"(%tensor, %shape) +// %1 = same_operands_and_result_shape_op(%tensor) +// ... uses of %1. +class DynamicReshapeOpSameShapeOpResult + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicReshapeOp op, + PatternRewriter& rewriter) const override { + Operation* def_op = op.operand().getDefiningOp(); + if (!def_op || !def_op->hasTrait()) { + return failure(); + } + Operation* input_def_op = def_op->getOperand(0).getDefiningOp(); + if (!input_def_op) { + return failure(); + } + auto reshape = dyn_cast(*input_def_op); + if (reshape && reshape.output_shape() == op.output_shape()) { + rewriter.replaceOp(op, {def_op->getResult(0)}); + return success(); + } + return failure(); + } +}; } // namespace void DynamicReshapeOp::getCanonicalizationPatterns( OwningRewritePatternList& results, MLIRContext* context) { - results.insert(context); + // clang-format off + results.insert< + DynamicReshapeOpNotActuallyDynamic, + DynamicReshapeOpSameShapeOpResult, + RemoveRedundantDynamicBroadcast, + RemoveRedundantDynamicReshape, + ShapeOfDynamicReshape + >(context); + // clang-format on } //===----------------------------------------------------------------------===// diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index c3e0143..1127d4c 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -1566,3 +1566,17 @@ func @permutation_broadcast_of_reshape(%arg: tensor, } // CHECK: mhlo.dynamic_reshape // CHECK: mhlo.dynamic_broadcast_in_dim + +// CHECK-LABEL: @reshape_of_same_shape_op_result +func @reshape_of_same_shape_op_result(%arg: tensor, + %shape: tensor<2xindex>) -> tensor { + %0 = "mhlo.dynamic_reshape"(%arg, %shape) + : (tensor, tensor<2xindex>) -> tensor + %1 = "mhlo.abs"(%0) : (tensor) -> tensor + %2 = "mhlo.dynamic_reshape"(%1, %shape) + : (tensor, tensor<2xindex>) -> tensor + return %2 : tensor +} +// CHECK: mhlo.dynamic_reshape +// CHECK-NEXT: mhlo.abs +// CHECK-NOT: mhlo.dynamic_reshape