From 095dc28e5c661c44cca96c320b6776093b511d51 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 5 Jan 2021 02:59:52 -0800 Subject: [PATCH] [KERNEL_GEN] Add canonicalizaton pattern to drop a redundant broadcast op. PiperOrigin-RevId: 350105790 --- lib/Dialect/mhlo/IR/hlo_ops.cc | 3 ++- lib/Dialect/mhlo/IR/mhlo_canonicalize.td | 7 +++++++ tests/canonicalize.mlir | 11 +++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index cec1ad7..998fdd6 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1282,7 +1282,8 @@ class DynamicReshapeOpNotActuallyDynamic void DynamicReshapeOp::getCanonicalizationPatterns( OwningRewritePatternList& results, MLIRContext* context) { results.insert(context); + RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape, + ShapeOfDynamicReshape>(context); } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/mhlo/IR/mhlo_canonicalize.td b/lib/Dialect/mhlo/IR/mhlo_canonicalize.td index e1b3dc6..a4c480b 100644 --- a/lib/Dialect/mhlo/IR/mhlo_canonicalize.td +++ b/lib/Dialect/mhlo/IR/mhlo_canonicalize.td @@ -33,3 +33,10 @@ def UnaryEinsumToEinsum : Pat< def RemoveRedundantDynamicReshape : Pat< (HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2), (HLO_DynamicReshapeOp $operand, $shape2)>; + +// A dynamic broadcast of a dynamic reshape with the same shape operand +// is a dynamic reshape. +def RemoveRedundantDynamicBroadcast : Pat< + (HLO_DynamicBroadcastInDimOp + (HLO_DynamicReshapeOp $operand, $shape), $shape, $dims), + (HLO_DynamicReshapeOp $operand, $shape)>; diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 8e17895..1b6ff8b 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -1540,3 +1540,14 @@ func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf3 return %1 : tensor<128xf32> // CHECK: return %arg0 : tensor<128xf32> } + +// CHECK-LABEL: @broadcast_of_reshape +func @broadcast_of_reshape(%arg: tensor, %shape: tensor<2xindex>) -> tensor { + %0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor, tensor<2xindex>) -> tensor + %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) { + broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> + } : (tensor, tensor<2xindex>) -> tensor + return %1 : tensor +} +// CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape" +// CHECK: return [[RESHAPE]]