Fix the shape of linalg.init_tensor in conv op lowering.
The output spatial dims are not as same as the input spatial dims. Only supports static output spatial dims for now. PiperOrigin-RevId: 359775479
This commit is contained in:
		
							parent
							
								
									c616963501
								
							
						
					
					
						commit
						a8f99ee0f5
					
				|  | @ -1476,9 +1476,14 @@ struct NormalConvOpOnTensorsConversion | |||
| 
 | ||||
|     // The output shape is N spatial_dims F.
 | ||||
|     SmallVector<Value, 8> dyn_sizes; | ||||
|     for (int64_t i = 0, e = rank - 1; i < e; ++i) { | ||||
|       if (!result_type.isDynamicDim(i)) continue; | ||||
|       dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, i)); | ||||
|     if (result_type.isDynamicDim(0)) { | ||||
|       dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, 0)); | ||||
|     } | ||||
|     for (int64_t i = 1, e = rank - 1; i < e; ++i) { | ||||
|       if (result_type.isDynamicDim(i)) { | ||||
|         return rewriter.notifyMatchFailure( | ||||
|             op, "expected output spatial dims to be static shapes"); | ||||
|       } | ||||
|     } | ||||
|     if (result_type.isDynamicDim(rank - 1)) { | ||||
|       dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1)); | ||||
|  |  | |||
|  | @ -1444,8 +1444,8 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf3 | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) | ||||
|   -> tensor<?x?x?xf32> { | ||||
| func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x8x?xf32>, %arg1: tensor<2x?x?xf32>) | ||||
|   -> tensor<?x7x?xf32> { | ||||
|   %0 = "mhlo.convolution"(%arg0, %arg1) { | ||||
|     batch_group_count = 1 : i64, | ||||
|     dimension_numbers = { | ||||
|  | @ -1463,31 +1463,29 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tenso | |||
|     padding = dense<[[0], [0]]> : tensor<2x1xi64>, | ||||
|     rhs_dilation = dense<1> : tensor<1xi64>, | ||||
|     window_strides = dense<1> : tensor<1xi64> | ||||
|   } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> | ||||
|   return %0 : tensor<?x?x?xf32> | ||||
|   } : (tensor<?x8x?xf32>, tensor<2x?x?xf32>) -> tensor<?x7x?xf32> | ||||
|   return %0 : tensor<?x7x?xf32> | ||||
| } | ||||
| // CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf | ||||
| // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]] | ||||
| // CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]] | ||||
| // CHECK:         %[[C0:.+]] = constant 0 : index | ||||
| // CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32> | ||||
| // CHECK:         %[[C1:.+]] = constant 1 : index | ||||
| // CHECK:         %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32> | ||||
| // CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32> | ||||
| // CHECK:         %[[C2:.+]] = constant 2 : index | ||||
| // CHECK:         %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32> | ||||
| // CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]] | ||||
| // CHECK:         %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32> | ||||
| // CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]] | ||||
| // CHECK:         %[[ZERO:.+]] = constant 0.000000e+00 : f32 | ||||
| // CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) | ||||
| // CHECK:         linalg.conv_1d_input_nwc_filter_wcf | ||||
| // CHECK-SAME:      {dilations = dense<1> : tensor<1xi64> | ||||
| // CHECK-SAME:       strides = dense<1> : tensor<1xi64>} | ||||
| // CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) | ||||
| // CHECK-SAME:    outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> | ||||
| // CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<?x8x?xf32>, tensor<2x?x?xf32>) | ||||
| // CHECK-SAME:    outs(%[[FILL]] : tensor<?x7x?xf32>) -> tensor<?x7x?xf32> | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) | ||||
|   -> tensor<?x?x?x?xf32> { | ||||
| func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3x2x?x?xf32>) | ||||
|   -> tensor<?x2x3x?xf32> { | ||||
|   %0 = "mhlo.convolution"(%arg0, %arg1) { | ||||
|     batch_group_count = 1 : i64, | ||||
|     dimension_numbers = { | ||||
|  | @ -1505,33 +1503,29 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<? | |||
|     padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>, | ||||
|     rhs_dilation = dense<1> : tensor<2xi64>, | ||||
|     window_strides = dense<1> : tensor<2xi64> | ||||
|   } : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> | ||||
|   return %0 : tensor<?x?x?x?xf32> | ||||
|   } : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x3x?xf32> | ||||
|   return %0 : tensor<?x2x3x?xf32> | ||||
| } | ||||
| // CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf | ||||
| // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]] | ||||
| // CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]] | ||||
| // CHECK:         %[[C0:.+]] = constant 0 : index | ||||
| // CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32> | ||||
| // CHECK:         %[[C1:.+]] = constant 1 : index | ||||
| // CHECK:         %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32> | ||||
| // CHECK:         %[[C2:.+]] = constant 2 : index | ||||
| // CHECK:         %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32> | ||||
| // CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32> | ||||
| // CHECK:         %[[C3:.+]] = constant 3 : index | ||||
| // CHECK:         %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32> | ||||
| // CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]] | ||||
| // CHECK:         %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32> | ||||
| // CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]] | ||||
| // CHECK:         %[[ZERO:.+]] = constant 0.000000e+00 : f32 | ||||
| // CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) | ||||
| // CHECK:         linalg.conv_2d_input_nhwc_filter_hwcf | ||||
| // CHECK-SAME:      {dilations = dense<1> : tensor<2xi64> | ||||
| // CHECK-SAME:       strides = dense<1> : tensor<2xi64>} | ||||
| // CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) | ||||
| // CHECK-SAME:    outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) | ||||
| // CHECK-SAME:    outs(%[[FILL]] : tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32> | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>) | ||||
|   -> tensor<?x?x?x?x?xf32> { | ||||
| func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x8x8x8x?xf32>, %arg1: tensor<2x2x2x?x?xf32>) | ||||
|   -> tensor<?x7x7x7x?xf32> { | ||||
|   %0 = "mhlo.convolution"(%arg0, %arg1) { | ||||
|     batch_group_count = 1 : i64, | ||||
|     dimension_numbers = { | ||||
|  | @ -1549,30 +1543,24 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tens | |||
|     padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>, | ||||
|     rhs_dilation = dense<1> : tensor<3xi64>, | ||||
|     window_strides = dense<1> : tensor<3xi64> | ||||
|   } : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> | ||||
|   return %0 : tensor<?x?x?x?x?xf32> | ||||
|   } : (tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>) -> tensor<?x7x7x7x?xf32> | ||||
|   return %0 : tensor<?x7x7x7x?xf32> | ||||
| } | ||||
| // CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf | ||||
| // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]] | ||||
| // CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]] | ||||
| // CHECK:         %[[C0:.+]] = constant 0 : index | ||||
| // CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32> | ||||
| // CHECK:         %[[C1:.+]] = constant 1 : index | ||||
| // CHECK:         %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32> | ||||
| // CHECK:         %[[C2:.+]] = constant 2 : index | ||||
| // CHECK:         %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32> | ||||
| // CHECK:         %[[C3:.+]] = constant 3 : index | ||||
| // CHECK:         %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32> | ||||
| // CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32> | ||||
| // CHECK:         %[[C4:.+]] = constant 4 : index | ||||
| // CHECK:         %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32> | ||||
| // CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]] | ||||
| // CHECK:         %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32> | ||||
| // CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]] | ||||
| // CHECK:         %[[ZERO:.+]] = constant 0.000000e+00 : f32 | ||||
| // CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) | ||||
| // CHECK:         linalg.conv_3d_input_ndhwc_filter_dhwcf | ||||
| // CHECK-SAME:      {dilations = dense<1> : tensor<3xi64> | ||||
| // CHECK-SAME:       strides = dense<1> : tensor<3xi64>} | ||||
| // CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) | ||||
| // CHECK-SAME:    outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>) | ||||
| // CHECK-SAME:    outs(%[[FILL]] : tensor<?x7x7x7x?xf32>) -> tensor<?x7x7x7x?xf32> | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue