Fold away shape.shape_of(mhlo.dynamic_reshape(inp, shape))
This specific pattern can be replaced with the shape passed to dynamic_reshape. This is implemented as a canonicalization on mhlo.dynamic_reshape to fit in the infrastructure of canonicalization. PiperOrigin-RevId: 342009365
This commit is contained in:
parent
7fc4985eae
commit
1dffa62fe9
|
@ -1268,7 +1268,8 @@ class DynamicReshapeOpNotActuallyDynamic
|
|||
|
||||
void DynamicReshapeOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<DynamicReshapeOpNotActuallyDynamic>(context);
|
||||
results.insert<DynamicReshapeOpNotActuallyDynamic, ShapeOfDynamicReshape>(
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -28,3 +28,6 @@ def DynamicBroadcastToOwnShape_2 : Pat<
|
|||
(HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
|
||||
(replaceWithValue $x)>;
|
||||
|
||||
def ShapeOfDynamicReshape : Pat<
|
||||
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
|
||||
(replaceWithValue $shape)>;
|
||||
|
|
|
@ -575,6 +575,16 @@ func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<
|
|||
return %0 : tensor<4x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_of_dynamic_reshape
|
||||
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
|
||||
func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> tensor<2xindex> {
|
||||
// CHECK: return [[ARG1]]
|
||||
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
%1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
|
||||
return %1 : tensor<2xindex>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: do_not_dce_while_with_outfeed
|
||||
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
// CHECK: mhlo.while
|
||||
|
|
Loading…
Reference in New Issue