From cabd4d9a062b93af3f406445faf4db401da5c6b3 Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Wed, 10 Mar 2021 05:43:10 -0800 Subject: [PATCH] Canonicalize dynamic_broadcast_in_dim to own shape with rank narrowing on the shape to a corresponding tensor.cast. PiperOrigin-RevId: 362028291 --- BUILD | 1 + lib/Dialect/mhlo/IR/hlo_ops.cc | 3 ++- lib/Dialect/mhlo/IR/hlo_patterns.td | 8 ++++++++ tests/canonicalize.mlir | 23 +++++++++++++++++++++++ 4 files changed, 34 insertions(+), 1 deletion(-) diff --git a/BUILD b/BUILD index a4480c2..a030db2 100644 --- a/BUILD +++ b/BUILD @@ -164,6 +164,7 @@ gentbl( deps = [ ":hlo_ops_td_files", "@llvm-project//mlir:StdOpsTdFiles", + "@llvm-project//mlir:TensorOpsTdFiles", ], ) diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 9132418..401dc16 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -920,7 +920,8 @@ class DynamicBroadcastInDimOpNotActuallyDynamic void DynamicBroadcastInDimOp::getCanonicalizationPatterns( OwningRewritePatternList& results, MLIRContext* context) { results.insert( + DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2, + DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>( context); } diff --git a/lib/Dialect/mhlo/IR/hlo_patterns.td b/lib/Dialect/mhlo/IR/hlo_patterns.td index 01564b8..73fca2d 100644 --- a/lib/Dialect/mhlo/IR/hlo_patterns.td +++ b/lib/Dialect/mhlo/IR/hlo_patterns.td @@ -16,6 +16,7 @@ limitations under the License. // Canonicalization patterns for the MHLO dialect. include "mlir/Dialect/Shape/IR/ShapeOps.td" +include "mlir/Dialect/Tensor/IR/TensorOps.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" // Canonicalization patterns. @@ -27,6 +28,13 @@ def DynamicBroadcastToOwnShape_1 : Pat< def DynamicBroadcastToOwnShape_2 : Pat< (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr), (replaceWithValue $x)>; +def DynamicBroadcastToOwnShape_3 : Pat< + (HLO_DynamicBroadcastInDimOp:$op $x, + (Tensor_CastOp (Shape_ToExtentTensorOp (Shape_ShapeOfOp $x))), $attr), + (Tensor_CastOp $x)>; +def DynamicBroadcastToOwnShape_4 : Pat< + (HLO_DynamicBroadcastInDimOp:$op $x, (Tensor_CastOp (Shape_ShapeOfOp $x)), $attr), + (Tensor_CastOp $x)>; def ShapeOfDynamicReshape : Pat< (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)), diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 5d48701..ec35f58 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -440,6 +440,29 @@ func @dynamic_broadcast_in_dim_to_same_shape_2(%arg0: tensor) -> tensor } +// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_3 +func @dynamic_broadcast_in_dim_to_same_shape_3(%arg0: tensor<*xf32>) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32> + %0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor + %1 = tensor.cast %0 : tensor to tensor<1xindex> + %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<*xf32> to tensor + // CHECK: return %[[RES]] : tensor + return %2 : tensor +} + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_4 +func @dynamic_broadcast_in_dim_to_same_shape_4(%arg0: tensor<*xf32>) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32> + %0 = shape.shape_of %arg0 : tensor<*xf32> -> !shape.shape + %1 = shape.to_extent_tensor %0 : !shape.shape -> tensor + %2 = tensor.cast %1 : tensor to tensor<1xindex> + %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<*xf32> to tensor + // CHECK: return %[[RES]] : tensor + return %3 : tensor +} + // CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> { %cst = mhlo.constant dense<0.000000e+00> : tensor