Fold away shape.shape_of(mhlo.dynamic_reshape(inp, shape))
This specific pattern can be replaced with the shape passed to dynamic_reshape. This is implemented as a canonicalization on mhlo.dynamic_reshape to fit in the infrastructure of canonicalization. PiperOrigin-RevId: 342009365
This commit is contained in:
parent
7fc4985eae
commit
1dffa62fe9
|
@ -1268,7 +1268,8 @@ class DynamicReshapeOpNotActuallyDynamic
|
||||||
|
|
||||||
void DynamicReshapeOp::getCanonicalizationPatterns(
|
void DynamicReshapeOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList& results, MLIRContext* context) {
|
OwningRewritePatternList& results, MLIRContext* context) {
|
||||||
results.insert<DynamicReshapeOpNotActuallyDynamic>(context);
|
results.insert<DynamicReshapeOpNotActuallyDynamic, ShapeOfDynamicReshape>(
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -28,3 +28,6 @@ def DynamicBroadcastToOwnShape_2 : Pat<
|
||||||
(HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
|
(HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
|
||||||
(replaceWithValue $x)>;
|
(replaceWithValue $x)>;
|
||||||
|
|
||||||
|
def ShapeOfDynamicReshape : Pat<
|
||||||
|
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
|
||||||
|
(replaceWithValue $shape)>;
|
||||||
|
|
|
@ -575,6 +575,16 @@ func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<
|
||||||
return %0 : tensor<4x1xf32>
|
return %0 : tensor<4x1xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @shape_of_dynamic_reshape
|
||||||
|
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
|
||||||
|
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
|
||||||
|
func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> tensor<2xindex> {
|
||||||
|
// CHECK: return [[ARG1]]
|
||||||
|
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
|
%1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
|
||||||
|
return %1 : tensor<2xindex>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: do_not_dce_while_with_outfeed
|
// CHECK-LABEL: do_not_dce_while_with_outfeed
|
||||||
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
|
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
|
||||||
// CHECK: mhlo.while
|
// CHECK: mhlo.while
|
||||||
|
|
Loading…
Reference in New Issue