Fix the shape of linalg.init_tensor in conv op lowering.
The output spatial dims are not as same as the input spatial dims. Only supports static output spatial dims for now. PiperOrigin-RevId: 359775479
This commit is contained in:
parent
c616963501
commit
a8f99ee0f5
|
@ -1476,9 +1476,14 @@ struct NormalConvOpOnTensorsConversion
|
|||
|
||||
// The output shape is N spatial_dims F.
|
||||
SmallVector<Value, 8> dyn_sizes;
|
||||
for (int64_t i = 0, e = rank - 1; i < e; ++i) {
|
||||
if (!result_type.isDynamicDim(i)) continue;
|
||||
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, i));
|
||||
if (result_type.isDynamicDim(0)) {
|
||||
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, 0));
|
||||
}
|
||||
for (int64_t i = 1, e = rank - 1; i < e; ++i) {
|
||||
if (result_type.isDynamicDim(i)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected output spatial dims to be static shapes");
|
||||
}
|
||||
}
|
||||
if (result_type.isDynamicDim(rank - 1)) {
|
||||
dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
|
||||
|
|
|
@ -1444,8 +1444,8 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf3
|
|||
|
||||
// -----
|
||||
|
||||
func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
|
||||
-> tensor<?x?x?xf32> {
|
||||
func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x8x?xf32>, %arg1: tensor<2x?x?xf32>)
|
||||
-> tensor<?x7x?xf32> {
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
|
@ -1463,31 +1463,29 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tenso
|
|||
padding = dense<[[0], [0]]> : tensor<2x1xi64>,
|
||||
rhs_dilation = dense<1> : tensor<1xi64>,
|
||||
window_strides = dense<1> : tensor<1xi64>
|
||||
} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
} : (tensor<?x8x?xf32>, tensor<2x?x?xf32>) -> tensor<?x7x?xf32>
|
||||
return %0 : tensor<?x7x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
|
||||
// CHECK: %[[C2:.+]] = constant 2 : index
|
||||
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]]
|
||||
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]]
|
||||
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
||||
// CHECK: linalg.conv_1d_input_nwc_filter_wcf
|
||||
// CHECK-SAME: {dilations = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: strides = dense<1> : tensor<1xi64>}
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x8x?xf32>, tensor<2x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x7x?xf32>) -> tensor<?x7x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
|
||||
-> tensor<?x?x?x?xf32> {
|
||||
func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3x2x?x?xf32>)
|
||||
-> tensor<?x2x3x?xf32> {
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
|
@ -1505,33 +1503,29 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?
|
|||
padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
|
||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||
window_strides = dense<1> : tensor<2xi64>
|
||||
} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?xf32>
|
||||
} : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x3x?xf32>
|
||||
return %0 : tensor<?x2x3x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[C2:.+]] = constant 2 : index
|
||||
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
|
||||
// CHECK: %[[C3:.+]] = constant 3 : index
|
||||
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
|
||||
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
|
||||
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
||||
// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
|
||||
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
|
||||
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
|
||||
-> tensor<?x?x?x?x?xf32> {
|
||||
func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x8x8x8x?xf32>, %arg1: tensor<2x2x2x?x?xf32>)
|
||||
-> tensor<?x7x7x7x?xf32> {
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
|
@ -1549,30 +1543,24 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tens
|
|||
padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
|
||||
rhs_dilation = dense<1> : tensor<3xi64>,
|
||||
window_strides = dense<1> : tensor<3xi64>
|
||||
} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?xf32>
|
||||
} : (tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>) -> tensor<?x7x7x7x?xf32>
|
||||
return %0 : tensor<?x7x7x7x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[C2:.+]] = constant 2 : index
|
||||
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[C3:.+]] = constant 3 : index
|
||||
// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
|
||||
// CHECK: %[[C4:.+]] = constant 4 : index
|
||||
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]]
|
||||
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]]
|
||||
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
||||
// CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
|
||||
// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>
|
||||
// CHECK-SAME: strides = dense<1> : tensor<3xi64>}
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x7x7x7x?xf32>) -> tensor<?x7x7x7x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
|
|
Loading…
Reference in New Issue