Add canonicalizer for Reshape(Broadcast(X)) pattern when it is an identity sequence
PiperOrigin-RevId: 343251257
This commit is contained in:
parent
61537008f4
commit
7f239c7ba2
|
@ -1046,6 +1046,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
|
|||
|
||||
let results = (outs HLO_StaticShapeTensor);
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
|
|
@ -1939,6 +1939,12 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
}
|
||||
|
||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||
MLIRContext* context) {
|
||||
results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>(
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Case Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -31,3 +31,15 @@ def DynamicBroadcastToOwnShape_2 : Pat<
|
|||
def ShapeOfDynamicReshape : Pat<
|
||||
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
|
||||
(replaceWithValue $shape)>;
|
||||
|
||||
def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
|
||||
def IdentityBroadcastReshape : Pat<
|
||||
(HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
|
||||
(replaceWithValue $input),
|
||||
[(HasSameType $input, $op)]>;
|
||||
|
||||
def IdentityBroadcastInDimReshape : Pat<
|
||||
(HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)),
|
||||
(replaceWithValue $input),
|
||||
[(HasSameType $input, $op)]>;
|
||||
|
|
|
@ -1483,3 +1483,21 @@ func @pad_fold() -> tensor<4x5xi32> {
|
|||
// CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
|
||||
// CHECK-SAME: ]> : tensor<4x5xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @identity_broadcast_reshape
|
||||
func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
|
||||
%0 = "mhlo.broadcast"(%arg0) {
|
||||
broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
|
||||
return %1 : tensor<128xf32>
|
||||
// CHECK: return %arg0 : tensor<128xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @identity_broadcast_in_dim_reshape
|
||||
func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
|
||||
%0 = "mhlo.broadcast_in_dim"(%arg0) {
|
||||
broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
|
||||
return %1 : tensor<128xf32>
|
||||
// CHECK: return %arg0 : tensor<128xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue