Add canonicalizer for Reshape(Broadcast(X)) pattern when it is an identity sequence
PiperOrigin-RevId: 343251257
This commit is contained in:
parent
61537008f4
commit
7f239c7ba2
|
@ -1046,6 +1046,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
|
||||||
|
|
||||||
let results = (outs HLO_StaticShapeTensor);
|
let results = (outs HLO_StaticShapeTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1939,6 +1939,12 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||||
|
MLIRContext* context) {
|
||||||
|
results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>(
|
||||||
|
context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Case Op
|
// Case Op
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -31,3 +31,15 @@ def DynamicBroadcastToOwnShape_2 : Pat<
|
||||||
def ShapeOfDynamicReshape : Pat<
|
def ShapeOfDynamicReshape : Pat<
|
||||||
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
|
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
|
||||||
(replaceWithValue $shape)>;
|
(replaceWithValue $shape)>;
|
||||||
|
|
||||||
|
def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||||
|
|
||||||
|
def IdentityBroadcastReshape : Pat<
|
||||||
|
(HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
|
||||||
|
(replaceWithValue $input),
|
||||||
|
[(HasSameType $input, $op)]>;
|
||||||
|
|
||||||
|
def IdentityBroadcastInDimReshape : Pat<
|
||||||
|
(HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)),
|
||||||
|
(replaceWithValue $input),
|
||||||
|
[(HasSameType $input, $op)]>;
|
||||||
|
|
|
@ -1483,3 +1483,21 @@ func @pad_fold() -> tensor<4x5xi32> {
|
||||||
// CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
|
// CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
|
||||||
// CHECK-SAME: ]> : tensor<4x5xi32>
|
// CHECK-SAME: ]> : tensor<4x5xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @identity_broadcast_reshape
|
||||||
|
func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
|
||||||
|
%0 = "mhlo.broadcast"(%arg0) {
|
||||||
|
broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>
|
||||||
|
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
|
||||||
|
return %1 : tensor<128xf32>
|
||||||
|
// CHECK: return %arg0 : tensor<128xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @identity_broadcast_in_dim_reshape
|
||||||
|
func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
|
||||||
|
%0 = "mhlo.broadcast_in_dim"(%arg0) {
|
||||||
|
broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32>
|
||||||
|
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
|
||||||
|
return %1 : tensor<128xf32>
|
||||||
|
// CHECK: return %arg0 : tensor<128xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue