Canonicalize dynamic_broadcast_in_dim to own shape with rank narrowing on the shape to a corresponding tensor.cast.
PiperOrigin-RevId: 362028291
This commit is contained in:
parent
507d9fb61d
commit
cabd4d9a06
1
BUILD
1
BUILD
|
@ -164,6 +164,7 @@ gentbl(
|
||||||
deps = [
|
deps = [
|
||||||
":hlo_ops_td_files",
|
":hlo_ops_td_files",
|
||||||
"@llvm-project//mlir:StdOpsTdFiles",
|
"@llvm-project//mlir:StdOpsTdFiles",
|
||||||
|
"@llvm-project//mlir:TensorOpsTdFiles",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -920,7 +920,8 @@ class DynamicBroadcastInDimOpNotActuallyDynamic
|
||||||
void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
|
void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList& results, MLIRContext* context) {
|
OwningRewritePatternList& results, MLIRContext* context) {
|
||||||
results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
|
results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
|
||||||
DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2>(
|
DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2,
|
||||||
|
DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>(
|
||||||
context);
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
// Canonicalization patterns for the MHLO dialect.
|
// Canonicalization patterns for the MHLO dialect.
|
||||||
|
|
||||||
include "mlir/Dialect/Shape/IR/ShapeOps.td"
|
include "mlir/Dialect/Shape/IR/ShapeOps.td"
|
||||||
|
include "mlir/Dialect/Tensor/IR/TensorOps.td"
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||||
|
|
||||||
// Canonicalization patterns.
|
// Canonicalization patterns.
|
||||||
|
@ -27,6 +28,13 @@ def DynamicBroadcastToOwnShape_1 : Pat<
|
||||||
def DynamicBroadcastToOwnShape_2 : Pat<
|
def DynamicBroadcastToOwnShape_2 : Pat<
|
||||||
(HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
|
(HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
|
||||||
(replaceWithValue $x)>;
|
(replaceWithValue $x)>;
|
||||||
|
def DynamicBroadcastToOwnShape_3 : Pat<
|
||||||
|
(HLO_DynamicBroadcastInDimOp:$op $x,
|
||||||
|
(Tensor_CastOp (Shape_ToExtentTensorOp (Shape_ShapeOfOp $x))), $attr),
|
||||||
|
(Tensor_CastOp $x)>;
|
||||||
|
def DynamicBroadcastToOwnShape_4 : Pat<
|
||||||
|
(HLO_DynamicBroadcastInDimOp:$op $x, (Tensor_CastOp (Shape_ShapeOfOp $x)), $attr),
|
||||||
|
(Tensor_CastOp $x)>;
|
||||||
|
|
||||||
def ShapeOfDynamicReshape : Pat<
|
def ShapeOfDynamicReshape : Pat<
|
||||||
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
|
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
|
||||||
|
|
|
@ -440,6 +440,29 @@ func @dynamic_broadcast_in_dim_to_same_shape_2(%arg0: tensor<?xf32>) -> tensor<?
|
||||||
return %2 : tensor<?xf32>
|
return %2 : tensor<?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_3
|
||||||
|
func @dynamic_broadcast_in_dim_to_same_shape_3(%arg0: tensor<*xf32>) -> tensor<?xf32> {
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>
|
||||||
|
%0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
|
||||||
|
%1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
|
||||||
|
%2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<*xf32> to tensor<?xf32>
|
||||||
|
// CHECK: return %[[RES]] : tensor<?xf32>
|
||||||
|
return %2 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_4
|
||||||
|
func @dynamic_broadcast_in_dim_to_same_shape_4(%arg0: tensor<*xf32>) -> tensor<?xf32> {
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>
|
||||||
|
%0 = shape.shape_of %arg0 : tensor<*xf32> -> !shape.shape
|
||||||
|
%1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<?xindex>
|
||||||
|
%2 = tensor.cast %1 : tensor<?xindex> to tensor<1xindex>
|
||||||
|
%3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<*xf32> to tensor<?xf32>
|
||||||
|
// CHECK: return %[[RES]] : tensor<?xf32>
|
||||||
|
return %3 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d
|
// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d
|
||||||
func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> {
|
func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> {
|
||||||
%cst = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
%cst = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
|
Loading…
Reference in New Issue