[KERNEL_GEN] Add a canonicalization pattern to drop a redundant dynamic reshape.
PiperOrigin-RevId: 351141868
This commit is contained in:
parent
47848764a5
commit
ecf1bf5132
|
@ -1277,13 +1277,54 @@ class DynamicReshapeOpNotActuallyDynamic
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Canonicalizes
|
||||||
|
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
|
||||||
|
// %1 = same_operands_and_result_shape_op(%tensor)
|
||||||
|
// %2 = "mhlo.dynamic_reshape"(%1, %shape)
|
||||||
|
// ... uses of %2.
|
||||||
|
//
|
||||||
|
// into
|
||||||
|
//
|
||||||
|
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
|
||||||
|
// %1 = same_operands_and_result_shape_op(%tensor)
|
||||||
|
// ... uses of %1.
|
||||||
|
class DynamicReshapeOpSameShapeOpResult
|
||||||
|
: public OpRewritePattern<DynamicReshapeOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(DynamicReshapeOp op,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
Operation* def_op = op.operand().getDefiningOp();
|
||||||
|
if (!def_op || !def_op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
|
||||||
|
if (!input_def_op) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto reshape = dyn_cast<DynamicReshapeOp>(*input_def_op);
|
||||||
|
if (reshape && reshape.output_shape() == op.output_shape()) {
|
||||||
|
rewriter.replaceOp(op, {def_op->getResult(0)});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void DynamicReshapeOp::getCanonicalizationPatterns(
|
void DynamicReshapeOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList& results, MLIRContext* context) {
|
OwningRewritePatternList& results, MLIRContext* context) {
|
||||||
results.insert<DynamicReshapeOpNotActuallyDynamic,
|
// clang-format off
|
||||||
RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape,
|
results.insert<
|
||||||
ShapeOfDynamicReshape>(context);
|
DynamicReshapeOpNotActuallyDynamic,
|
||||||
|
DynamicReshapeOpSameShapeOpResult,
|
||||||
|
RemoveRedundantDynamicBroadcast,
|
||||||
|
RemoveRedundantDynamicReshape,
|
||||||
|
ShapeOfDynamicReshape
|
||||||
|
>(context);
|
||||||
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1566,3 +1566,17 @@ func @permutation_broadcast_of_reshape(%arg: tensor<?xf32>,
|
||||||
}
|
}
|
||||||
// CHECK: mhlo.dynamic_reshape
|
// CHECK: mhlo.dynamic_reshape
|
||||||
// CHECK: mhlo.dynamic_broadcast_in_dim
|
// CHECK: mhlo.dynamic_broadcast_in_dim
|
||||||
|
|
||||||
|
// CHECK-LABEL: @reshape_of_same_shape_op_result
|
||||||
|
func @reshape_of_same_shape_op_result(%arg: tensor<?xf32>,
|
||||||
|
%shape: tensor<2xindex>) -> tensor<?x?xf32> {
|
||||||
|
%0 = "mhlo.dynamic_reshape"(%arg, %shape)
|
||||||
|
: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
|
%1 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
%2 = "mhlo.dynamic_reshape"(%1, %shape)
|
||||||
|
: (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
|
return %2 : tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
// CHECK: mhlo.dynamic_reshape
|
||||||
|
// CHECK-NEXT: mhlo.abs
|
||||||
|
// CHECK-NOT: mhlo.dynamic_reshape
|
||||||
|
|
Loading…
Reference in New Issue