Fix the shape of linalg.init_tensor in conv op lowering.

The output spatial dims are not the same as the input spatial dims, so the init_tensor shape cannot be taken from the input for those dims. Only static output spatial dims are supported for now.

PiperOrigin-RevId: 359775479
Hanhan Wang 2021-02-26 09:33:21 -08:00 committed by TensorFlow MLIR Team
parent c616963501
commit a8f99ee0f5
2 changed files with 35 additions and 42 deletions
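
For context on why the spatial sizes in the updated tests are static constants: with no padding, unit stride, and unit dilation (the configuration all three tests use), the output spatial extent is input - filter + 1, e.g. 8 - 2 + 1 = 7 in the 1-D and 3-D cases. The snippet below is a minimal sketch of that arithmetic, not part of this commit; the helper name convOutputDim is made up for illustration.

// Minimal sketch (not part of this commit) of the convolution output-size
// arithmetic behind the static spatial dims in the updated tests; the
// helper name convOutputDim is hypothetical.
#include <cassert>
#include <cstdint>

int64_t convOutputDim(int64_t input, int64_t filter, int64_t stride,
                      int64_t dilation, int64_t pad_low, int64_t pad_high) {
  // Effective window extent of a dilated filter.
  int64_t window = (filter - 1) * dilation + 1;
  return (input + pad_low + pad_high - window) / stride + 1;
}

int main() {
  // 1-D test: tensor<?x8x?xf32> conv tensor<2x?x?xf32> -> tensor<?x7x?xf32>.
  assert(convOutputDim(8, 2, /*stride=*/1, /*dilation=*/1, 0, 0) == 7);
  // 3-D test: every spatial dim is 8 with filter size 2 -> 7.
  assert(convOutputDim(8, 2, 1, 1, 0, 0) == 7);
  return 0;
}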


@@ -1476,9 +1476,14 @@ struct NormalConvOpOnTensorsConversion
     // The output shape is N spatial_dims F.
     SmallVector<Value, 8> dyn_sizes;
-    for (int64_t i = 0, e = rank - 1; i < e; ++i) {
-      if (!result_type.isDynamicDim(i)) continue;
-      dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, i));
+    if (result_type.isDynamicDim(0)) {
+      dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, 0));
+    }
+    for (int64_t i = 1, e = rank - 1; i < e; ++i) {
+      if (result_type.isDynamicDim(i)) {
+        return rewriter.notifyMatchFailure(
+            op, "expected output spatial dims to be static shapes");
+      }
     }
     if (result_type.isDynamicDim(rank - 1)) {
       dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));


@@ -1444,8 +1444,8 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf32>

 // -----

-func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
-    -> tensor<?x?x?xf32> {
+func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x8x?xf32>, %arg1: tensor<2x?x?xf32>)
+    -> tensor<?x7x?xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {
     batch_group_count = 1 : i64,
     dimension_numbers = {
@@ -1463,31 +1463,29 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
     padding = dense<[[0], [0]]> : tensor<2x1xi64>,
     rhs_dilation = dense<1> : tensor<1xi64>,
     window_strides = dense<1> : tensor<1xi64>
-  } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-  return %0 : tensor<?x?x?xf32>
+  } : (tensor<?x8x?xf32>, tensor<2x?x?xf32>) -> tensor<?x7x?xf32>
+  return %0 : tensor<?x7x?xf32>
 }
 // CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
 // CHECK: %[[C0:.+]] = constant 0 : index
-// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[C1:.+]] = constant 1 : index
-// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
 // CHECK: %[[C2:.+]] = constant 2 : index
-// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]]
+// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]]
 // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
 // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
 // CHECK: linalg.conv_1d_input_nwc_filter_wcf
 // CHECK-SAME: {dilations = dense<1> : tensor<1xi64>
 // CHECK-SAME: strides = dense<1> : tensor<1xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x8x?xf32>, tensor<2x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x7x?xf32>) -> tensor<?x7x?xf32>

 // -----

-func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
-    -> tensor<?x?x?x?xf32> {
+func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3x2x?x?xf32>)
+    -> tensor<?x2x3x?xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {
     batch_group_count = 1 : i64,
     dimension_numbers = {
@@ -1505,33 +1503,29 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
     padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
     rhs_dilation = dense<1> : tensor<2xi64>,
     window_strides = dense<1> : tensor<2xi64>
-  } : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-  return %0 : tensor<?x?x?x?xf32>
+  } : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x3x?xf32>
+  return %0 : tensor<?x2x3x?xf32>
 }
 // CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
 // CHECK: %[[C0:.+]] = constant 0 : index
-// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
-// CHECK: %[[C1:.+]] = constant 1 : index
-// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
-// CHECK: %[[C2:.+]] = constant 2 : index
-// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
 // CHECK: %[[C3:.+]] = constant 3 : index
-// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
+// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
 // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
 // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
 // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
 // CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME: strides = dense<1> : tensor<2xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32>

 // -----

-func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
-    -> tensor<?x?x?x?x?xf32> {
+func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x8x8x8x?xf32>, %arg1: tensor<2x2x2x?x?xf32>)
+    -> tensor<?x7x7x7x?xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {
     batch_group_count = 1 : i64,
     dimension_numbers = {
@@ -1549,30 +1543,24 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
     padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
     rhs_dilation = dense<1> : tensor<3xi64>,
     window_strides = dense<1> : tensor<3xi64>
-  } : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-  return %0 : tensor<?x?x?x?x?xf32>
+  } : (tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>) -> tensor<?x7x7x7x?xf32>
+  return %0 : tensor<?x7x7x7x?xf32>
 }
 // CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
 // CHECK: %[[C0:.+]] = constant 0 : index
-// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32>
-// CHECK: %[[C1:.+]] = constant 1 : index
-// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32>
-// CHECK: %[[C2:.+]] = constant 2 : index
-// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32>
-// CHECK: %[[C3:.+]] = constant 3 : index
-// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
 // CHECK: %[[C4:.+]] = constant 4 : index
-// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]]
+// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]]
 // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
 // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
 // CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
 // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>
 // CHECK-SAME: strides = dense<1> : tensor<3xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x7x7x7x?xf32>) -> tensor<?x7x7x7x?xf32>

 // -----