[KERNEL_GEN] Add a canonicalization pattern to drop a redundant dynamic reshape.
PiperOrigin-RevId: 351141868
This commit is contained in:
parent
47848764a5
commit
ecf1bf5132
|
@ -1277,13 +1277,54 @@ class DynamicReshapeOpNotActuallyDynamic
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Canonicalizes
|
||||
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
|
||||
// %1 = same_operands_and_result_shape_op(%tensor)
|
||||
// %2 = "mhlo.dynamic_reshape"(%1, %shape)
|
||||
// ... uses of %2.
|
||||
//
|
||||
// into
|
||||
//
|
||||
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
|
||||
// %1 = same_operands_and_result_shape_op(%tensor)
|
||||
// ... uses of %1.
|
||||
class DynamicReshapeOpSameShapeOpResult
|
||||
: public OpRewritePattern<DynamicReshapeOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(DynamicReshapeOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Operation* def_op = op.operand().getDefiningOp();
|
||||
if (!def_op || !def_op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
|
||||
return failure();
|
||||
}
|
||||
Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
|
||||
if (!input_def_op) {
|
||||
return failure();
|
||||
}
|
||||
auto reshape = dyn_cast<DynamicReshapeOp>(*input_def_op);
|
||||
if (reshape && reshape.output_shape() == op.output_shape()) {
|
||||
rewriter.replaceOp(op, {def_op->getResult(0)});
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void DynamicReshapeOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<DynamicReshapeOpNotActuallyDynamic,
|
||||
RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape,
|
||||
ShapeOfDynamicReshape>(context);
|
||||
// clang-format off
|
||||
results.insert<
|
||||
DynamicReshapeOpNotActuallyDynamic,
|
||||
DynamicReshapeOpSameShapeOpResult,
|
||||
RemoveRedundantDynamicBroadcast,
|
||||
RemoveRedundantDynamicReshape,
|
||||
ShapeOfDynamicReshape
|
||||
>(context);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1566,3 +1566,17 @@ func @permutation_broadcast_of_reshape(%arg: tensor<?xf32>,
|
|||
}
|
||||
// CHECK: mhlo.dynamic_reshape
|
||||
// CHECK: mhlo.dynamic_broadcast_in_dim
|
||||
|
||||
// CHECK-LABEL: @reshape_of_same_shape_op_result
|
||||
func @reshape_of_same_shape_op_result(%arg: tensor<?xf32>,
|
||||
%shape: tensor<2xindex>) -> tensor<?x?xf32> {
|
||||
%0 = "mhlo.dynamic_reshape"(%arg, %shape)
|
||||
: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
%1 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%2 = "mhlo.dynamic_reshape"(%1, %shape)
|
||||
: (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
return %2 : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK: mhlo.dynamic_reshape
|
||||
// CHECK-NEXT: mhlo.abs
|
||||
// CHECK-NOT: mhlo.dynamic_reshape
|
||||
|
|
Loading…
Reference in New Issue