[KERNEL_GEN] Add canonicalizaton pattern to drop a redundant broadcast op.
PiperOrigin-RevId: 350105790
This commit is contained in:
parent
2727ed4cf2
commit
095dc28e5c
|
@ -1282,7 +1282,8 @@ class DynamicReshapeOpNotActuallyDynamic
|
||||||
void DynamicReshapeOp::getCanonicalizationPatterns(
|
void DynamicReshapeOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList& results, MLIRContext* context) {
|
OwningRewritePatternList& results, MLIRContext* context) {
|
||||||
results.insert<DynamicReshapeOpNotActuallyDynamic,
|
results.insert<DynamicReshapeOpNotActuallyDynamic,
|
||||||
RemoveRedundantDynamicReshape, ShapeOfDynamicReshape>(context);
|
RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape,
|
||||||
|
ShapeOfDynamicReshape>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -33,3 +33,10 @@ def UnaryEinsumToEinsum : Pat<
|
||||||
def RemoveRedundantDynamicReshape : Pat<
|
def RemoveRedundantDynamicReshape : Pat<
|
||||||
(HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2),
|
(HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2),
|
||||||
(HLO_DynamicReshapeOp $operand, $shape2)>;
|
(HLO_DynamicReshapeOp $operand, $shape2)>;
|
||||||
|
|
||||||
|
// A dynamic broadcast of a dynamic reshape with the same shape operand
|
||||||
|
// is a dynamic reshape.
|
||||||
|
def RemoveRedundantDynamicBroadcast : Pat<
|
||||||
|
(HLO_DynamicBroadcastInDimOp
|
||||||
|
(HLO_DynamicReshapeOp $operand, $shape), $shape, $dims),
|
||||||
|
(HLO_DynamicReshapeOp $operand, $shape)>;
|
||||||
|
|
|
@ -1540,3 +1540,14 @@ func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf3
|
||||||
return %1 : tensor<128xf32>
|
return %1 : tensor<128xf32>
|
||||||
// CHECK: return %arg0 : tensor<128xf32>
|
// CHECK: return %arg0 : tensor<128xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @broadcast_of_reshape
|
||||||
|
func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
|
||||||
|
%0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
|
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
|
||||||
|
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||||
|
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
|
return %1 : tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
// CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape"
|
||||||
|
// CHECK: return [[RESHAPE]]
|
||||||
|
|
Loading…
Reference in New Issue