Restrict canonicalization to avoid changing type

Issue #47516

PiperOrigin-RevId: 363300979
This commit is contained in:
Jacques Pienaar 2021-03-16 16:53:10 -07:00 committed by TensorFlow MLIR Team
parent caae2525ef
commit a58e62590e
1 changed files with 5 additions and 4 deletions

View File

@ -36,12 +36,13 @@ def DynamicBroadcastToOwnShape_4 : Pat<
(HLO_DynamicBroadcastInDimOp:$op $x, (Tensor_CastOp (Shape_ShapeOfOp $x)), $attr),
(Tensor_CastOp $x)>;
def ShapeOfDynamicReshape : Pat<
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
(replaceWithValue $shape)>;
def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
def ShapeOfDynamicReshape : Pat<
(Shape_ShapeOfOp:$op (HLO_DynamicReshapeOp $x, $shape)),
(replaceWithValue $shape),
[(HasSameType $shape, $op)]>;
def IdentityBroadcastReshape : Pat<
(HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
(replaceWithValue $input),