[KERNEL_GEN] Add canonicalizaton pattern to drop a redundant broadcast op.

PiperOrigin-RevId: 350105790
This commit is contained in:
Alexander Belyaev 2021-01-05 02:59:52 -08:00 committed by TensorFlow MLIR Team
parent 2727ed4cf2
commit 095dc28e5c
3 changed files with 20 additions and 1 deletions

View File

@ -1282,7 +1282,8 @@ class DynamicReshapeOpNotActuallyDynamic
void DynamicReshapeOp::getCanonicalizationPatterns( void DynamicReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicReshapeOpNotActuallyDynamic, results.insert<DynamicReshapeOpNotActuallyDynamic,
RemoveRedundantDynamicReshape, ShapeOfDynamicReshape>(context); RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape,
ShapeOfDynamicReshape>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -33,3 +33,10 @@ def UnaryEinsumToEinsum : Pat<
def RemoveRedundantDynamicReshape : Pat< def RemoveRedundantDynamicReshape : Pat<
(HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2), (HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2),
(HLO_DynamicReshapeOp $operand, $shape2)>; (HLO_DynamicReshapeOp $operand, $shape2)>;
// A dynamic broadcast of a dynamic reshape with the same shape operand
// is a dynamic reshape.
def RemoveRedundantDynamicBroadcast : Pat<
(HLO_DynamicBroadcastInDimOp
(HLO_DynamicReshapeOp $operand, $shape), $shape, $dims),
(HLO_DynamicReshapeOp $operand, $shape)>;

View File

@ -1540,3 +1540,14 @@ func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf3
return %1 : tensor<128xf32> return %1 : tensor<128xf32>
// CHECK: return %arg0 : tensor<128xf32> // CHECK: return %arg0 : tensor<128xf32>
} }
// CHECK-LABEL: @broadcast_of_reshape
func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
%0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape"
// CHECK: return [[RESHAPE]]