Canonicalize dynamic_broadcast_in_dim ops that broadcast a value to its own shape — where the shape operand has been rank-narrowed via a cast — into a corresponding tensor.cast of the value.

PiperOrigin-RevId: 362028291
This commit is contained in:
Stephan Herhut 2021-03-10 05:43:10 -08:00 committed by TensorFlow MLIR Team
parent 507d9fb61d
commit cabd4d9a06
4 changed files with 34 additions and 1 deletion

1
BUILD
View File

@@ -164,6 +164,7 @@ gentbl(
     deps = [
         ":hlo_ops_td_files",
         "@llvm-project//mlir:StdOpsTdFiles",
+        "@llvm-project//mlir:TensorOpsTdFiles",
     ],
 )

View File

@@ -920,7 +920,8 @@ class DynamicBroadcastInDimOpNotActuallyDynamic
 void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
     OwningRewritePatternList& results, MLIRContext* context) {
   results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
-                 DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2>(
+                 DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2,
+                 DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>(
       context);
 }

View File

@@ -16,6 +16,7 @@ limitations under the License.
 // Canonicalization patterns for the MHLO dialect.
 include "mlir/Dialect/Shape/IR/ShapeOps.td"
+include "mlir/Dialect/Tensor/IR/TensorOps.td"
 include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
 // Canonicalization patterns.
@@ -27,6 +28,13 @@ def DynamicBroadcastToOwnShape_1 : Pat<
 def DynamicBroadcastToOwnShape_2 : Pat<
   (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
   (replaceWithValue $x)>;
+def DynamicBroadcastToOwnShape_3 : Pat<
+  (HLO_DynamicBroadcastInDimOp:$op $x,
+      (Tensor_CastOp (Shape_ToExtentTensorOp (Shape_ShapeOfOp $x))), $attr),
+  (Tensor_CastOp $x)>;
+def DynamicBroadcastToOwnShape_4 : Pat<
+  (HLO_DynamicBroadcastInDimOp:$op $x, (Tensor_CastOp (Shape_ShapeOfOp $x)), $attr),
+  (Tensor_CastOp $x)>;
 def ShapeOfDynamicReshape : Pat<
   (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),

View File

@@ -440,6 +440,29 @@ func @dynamic_broadcast_in_dim_to_same_shape_2(%arg0: tensor<?xf32>) -> tensor<?
   return %2 : tensor<?xf32>
 }
// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_3
func @dynamic_broadcast_in_dim_to_same_shape_3(%arg0: tensor<*xf32>) -> tensor<?xf32> {
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>
%0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
%1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
%2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<*xf32> to tensor<?xf32>
// CHECK: return %[[RES]] : tensor<?xf32>
return %2 : tensor<?xf32>
}
// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_4
func @dynamic_broadcast_in_dim_to_same_shape_4(%arg0: tensor<*xf32>) -> tensor<?xf32> {
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>
%0 = shape.shape_of %arg0 : tensor<*xf32> -> !shape.shape
%1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<?xindex>
%2 = tensor.cast %1 : tensor<?xindex> to tensor<1xindex>
%3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<*xf32> to tensor<?xf32>
// CHECK: return %[[RES]] : tensor<?xf32>
return %3 : tensor<?xf32>
}
 // CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d
 func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> {
   %cst = mhlo.constant dense<0.000000e+00> : tensor<f32>