Add canonicalizer for Reshape(Broadcast(X)) pattern when it is an identity sequence

PiperOrigin-RevId: 343251257
This commit is contained in:
A. Unique TensorFlower 2020-11-19 02:32:08 -08:00 committed by TensorFlow MLIR Team
parent 61537008f4
commit 7f239c7ba2
4 changed files with 37 additions and 0 deletions

View File

@ -1046,6 +1046,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
let results = (outs HLO_StaticShapeTensor); let results = (outs HLO_StaticShapeTensor);
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1;
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
} }

View File

@ -1939,6 +1939,12 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
return {}; return {};
} }
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>(
context);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Case Op // Case Op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -31,3 +31,15 @@ def DynamicBroadcastToOwnShape_2 : Pat<
def ShapeOfDynamicReshape : Pat< def ShapeOfDynamicReshape : Pat<
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)), (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
(replaceWithValue $shape)>; (replaceWithValue $shape)>;
def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
def IdentityBroadcastReshape : Pat<
(HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
(replaceWithValue $input),
[(HasSameType $input, $op)]>;
def IdentityBroadcastInDimReshape : Pat<
(HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)),
(replaceWithValue $input),
[(HasSameType $input, $op)]>;

View File

@ -1483,3 +1483,21 @@ func @pad_fold() -> tensor<4x5xi32> {
// CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1] // CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
// CHECK-SAME: ]> : tensor<4x5xi32> // CHECK-SAME: ]> : tensor<4x5xi32>
} }
// CHECK-LABEL: @identity_broadcast_reshape
func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
%0 = "mhlo.broadcast"(%arg0) {
broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
return %1 : tensor<128xf32>
// CHECK: return %arg0 : tensor<128xf32>
}
// CHECK-LABEL: @identity_broadcast_in_dim_reshape
func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
%0 = "mhlo.broadcast_in_dim"(%arg0) {
broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32>
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
return %1 : tensor<128xf32>
// CHECK: return %arg0 : tensor<128xf32>
}