[KERNEL_GEN] Add a canonicalization pattern to drop a redundant dynamic reshape.

PiperOrigin-RevId: 351141868
This commit is contained in:
Alexander Belyaev 2021-01-11 06:36:14 -08:00 committed by TensorFlow MLIR Team
parent 47848764a5
commit ecf1bf5132
2 changed files with 58 additions and 3 deletions

View File

@ -1277,13 +1277,54 @@ class DynamicReshapeOpNotActuallyDynamic
return success(); return success();
} }
}; };
// Canonicalizes
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// %2 = "mhlo.dynamic_reshape"(%1, %shape)
// ... uses of %2.
//
// into
//
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// ... uses of %1.
class DynamicReshapeOpSameShapeOpResult
: public OpRewritePattern<DynamicReshapeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicReshapeOp op,
PatternRewriter& rewriter) const override {
Operation* def_op = op.operand().getDefiningOp();
if (!def_op || !def_op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
return failure();
}
Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
if (!input_def_op) {
return failure();
}
auto reshape = dyn_cast<DynamicReshapeOp>(*input_def_op);
if (reshape && reshape.output_shape() == op.output_shape()) {
rewriter.replaceOp(op, {def_op->getResult(0)});
return success();
}
return failure();
}
};
} // namespace } // namespace
void DynamicReshapeOp::getCanonicalizationPatterns( void DynamicReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicReshapeOpNotActuallyDynamic, // clang-format off
RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape, results.insert<
ShapeOfDynamicReshape>(context); DynamicReshapeOpNotActuallyDynamic,
DynamicReshapeOpSameShapeOpResult,
RemoveRedundantDynamicBroadcast,
RemoveRedundantDynamicReshape,
ShapeOfDynamicReshape
>(context);
// clang-format on
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1566,3 +1566,17 @@ func @permutation_broadcast_of_reshape(%arg: tensor<?xf32>,
} }
// CHECK: mhlo.dynamic_reshape // CHECK: mhlo.dynamic_reshape
// CHECK: mhlo.dynamic_broadcast_in_dim // CHECK: mhlo.dynamic_broadcast_in_dim
// CHECK-LABEL: @reshape_of_same_shape_op_result
func @reshape_of_same_shape_op_result(%arg: tensor<?xf32>,
%shape: tensor<2xindex>) -> tensor<?x?xf32> {
%0 = "mhlo.dynamic_reshape"(%arg, %shape)
: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%1 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.dynamic_reshape"(%1, %shape)
: (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}
// CHECK: mhlo.dynamic_reshape
// CHECK-NEXT: mhlo.abs
// CHECK-NOT: mhlo.dynamic_reshape