diff --git a/lib/Dialect/mhlo/IR/hlo_patterns.td b/lib/Dialect/mhlo/IR/hlo_patterns.td index 73fca2d..58193a8 100644 --- a/lib/Dialect/mhlo/IR/hlo_patterns.td +++ b/lib/Dialect/mhlo/IR/hlo_patterns.td @@ -36,12 +36,13 @@ def DynamicBroadcastToOwnShape_4 : Pat< (HLO_DynamicBroadcastInDimOp:$op $x, (Tensor_CastOp (Shape_ShapeOfOp $x)), $attr), (Tensor_CastOp $x)>; -def ShapeOfDynamicReshape : Pat< - (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)), - (replaceWithValue $shape)>; - def HasSameType : Constraint>; +def ShapeOfDynamicReshape : Pat< + (Shape_ShapeOfOp:$op (HLO_DynamicReshapeOp $x, $shape)), + (replaceWithValue $shape), + [(HasSameType $shape, $op)]>; + def IdentityBroadcastReshape : Pat< (HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)), (replaceWithValue $input),