[KERNEL_GEN] Add a canonicalization pattern to drop a redundant dynamic reshape.

PiperOrigin-RevId: 351141868
This commit is contained in:
Alexander Belyaev 2021-01-11 06:36:14 -08:00 committed by TensorFlow MLIR Team
parent 47848764a5
commit ecf1bf5132
2 changed files with 58 additions and 3 deletions

View File

@ -1277,13 +1277,54 @@ class DynamicReshapeOpNotActuallyDynamic
return success();
}
};
// Canonicalizes
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// %2 = "mhlo.dynamic_reshape"(%1, %shape)
// ... uses of %2.
//
// into
//
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// ... uses of %1.
class DynamicReshapeOpSameShapeOpResult
: public OpRewritePattern<DynamicReshapeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicReshapeOp op,
PatternRewriter& rewriter) const override {
Operation* def_op = op.operand().getDefiningOp();
if (!def_op || !def_op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
return failure();
}
Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
if (!input_def_op) {
return failure();
}
auto reshape = dyn_cast<DynamicReshapeOp>(*input_def_op);
if (reshape && reshape.output_shape() == op.output_shape()) {
rewriter.replaceOp(op, {def_op->getResult(0)});
return success();
}
return failure();
}
};
} // namespace
void DynamicReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicReshapeOpNotActuallyDynamic,
RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape,
ShapeOfDynamicReshape>(context);
// clang-format off
results.insert<
DynamicReshapeOpNotActuallyDynamic,
DynamicReshapeOpSameShapeOpResult,
RemoveRedundantDynamicBroadcast,
RemoveRedundantDynamicReshape,
ShapeOfDynamicReshape
>(context);
// clang-format on
}
//===----------------------------------------------------------------------===//

View File

@ -1566,3 +1566,17 @@ func @permutation_broadcast_of_reshape(%arg: tensor<?xf32>,
}
// CHECK: mhlo.dynamic_reshape
// CHECK: mhlo.dynamic_broadcast_in_dim
// CHECK-LABEL: @reshape_of_same_shape_op_result
func @reshape_of_same_shape_op_result(%arg: tensor<?xf32>,
%shape: tensor<2xindex>) -> tensor<?x?xf32> {
%0 = "mhlo.dynamic_reshape"(%arg, %shape)
: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%1 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.dynamic_reshape"(%1, %shape)
: (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}
// CHECK: mhlo.dynamic_reshape
// CHECK-NEXT: mhlo.abs
// CHECK-NOT: mhlo.dynamic_reshape