Add a canonicalization pattern to remove redundant dynamic_reshapes.
PiperOrigin-RevId: 344517381
This commit is contained in:
parent
f183e95e69
commit
d14c63da54
|
@ -1268,8 +1268,8 @@ class DynamicReshapeOpNotActuallyDynamic
|
||||||
|
|
||||||
void DynamicReshapeOp::getCanonicalizationPatterns(
|
void DynamicReshapeOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList& results, MLIRContext* context) {
|
OwningRewritePatternList& results, MLIRContext* context) {
|
||||||
results.insert<DynamicReshapeOpNotActuallyDynamic, ShapeOfDynamicReshape>(
|
results.insert<DynamicReshapeOpNotActuallyDynamic,
|
||||||
context);
|
RemoveRedundantDynamicReshape, ShapeOfDynamicReshape>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -28,3 +28,8 @@ def UnaryEinsumToEinsum : Pat<
|
||||||
(HLO_UnaryEinsumOp $operand, $equation),
|
(HLO_UnaryEinsumOp $operand, $equation),
|
||||||
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
|
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
|
||||||
$operand, (UnaryToBinaryEinsumEq $equation))>;
|
$operand, (UnaryToBinaryEinsumEq $equation))>;
|
||||||
|
|
||||||
|
// A dynamic reshape of a dynamic reshape is a dynamic reshape.
|
||||||
|
def RemoveRedundantDynamicReshape : Pat<
|
||||||
|
(HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2),
|
||||||
|
(HLO_DynamicReshapeOp $operand, $shape2)>;
|
||||||
|
|
|
@ -585,6 +585,20 @@ func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) ->
|
||||||
return %1 : tensor<2xindex>
|
return %1 : tensor<2xindex>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_reshape_of_dynamic_reshape
|
||||||
|
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
|
||||||
|
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
|
||||||
|
func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor<?xf16>, %shape: tensor<?xindex>) -> tensor<?xf16> {
|
||||||
|
// CHECK: [[RES:%[a-zA-Z0-9]+]] = "mhlo.dynamic_reshape"([[ARG0]], %{{[a-zA-Z0-9]+}}) : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
|
||||||
|
// CHECK: return [[RES]]
|
||||||
|
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf16>, tensor<?xindex>) -> tensor<*xf16>
|
||||||
|
%1 = shape.shape_of %0 : tensor<*xf16> -> tensor<?xindex>
|
||||||
|
%2 = shape.num_elements %1 : tensor<?xindex> -> index
|
||||||
|
%3 = tensor_from_elements %2 : tensor<1xindex>
|
||||||
|
%4 = "mhlo.dynamic_reshape"(%0, %3) : (tensor<*xf16>, tensor<1xindex>) -> tensor<?xf16>
|
||||||
|
return %4 : tensor<?xf16>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: do_not_dce_while_with_outfeed
|
// CHECK-LABEL: do_not_dce_while_with_outfeed
|
||||||
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
|
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
|
||||||
// CHECK: mhlo.while
|
// CHECK: mhlo.while
|
||||||
|
|
Loading…
Reference in New Issue