[MHLO:Linalg] Add support for lowering reshape of unsigned tensors
PiperOrigin-RevId: 373461627
This commit is contained in:
		
							parent
							
								
									a2c9b3c9d7
								
							
						
					
					
						commit
						d764806c1e
					
				|  | @ -782,6 +782,9 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> { | ||||||
|     if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) |     if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) | ||||||
|       return failure(); |       return failure(); | ||||||
| 
 | 
 | ||||||
|  |     result_type = this->typeConverter->convertType(result_type) | ||||||
|  |                       .template cast<ShapedType>(); | ||||||
|  | 
 | ||||||
|     // Compute the reassociation maps for the linalg operation.
 |     // Compute the reassociation maps for the linalg operation.
 | ||||||
|     ArrayRef<int64_t> src_shape = |     ArrayRef<int64_t> src_shape = | ||||||
|         (operand_type.getRank() > result_type.getRank() |         (operand_type.getRank() > result_type.getRank() | ||||||
|  |  | ||||||
|  | @ -482,6 +482,19 @@ func @reshape_0D_1D(%arg0: tensor<i32>) -> tensor<1xi32> { | ||||||
| 
 | 
 | ||||||
| // ----- | // ----- | ||||||
| 
 | 
 | ||||||
|  | func @reshape_0D_1D_unsigned(%arg0: tensor<ui32>) -> tensor<1xui32> { | ||||||
|  |   %0 = "mhlo.reshape"(%arg0) : (tensor<ui32>) -> tensor<1xui32> | ||||||
|  |   return %0 : tensor<1xui32> | ||||||
|  | } | ||||||
|  | // CHECK-LABEL: func @reshape_0D_1D_unsigned | ||||||
|  | // CHECK-SAME:    %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]] | ||||||
|  | // CHECK:         %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<ui32> to tensor<i32> | ||||||
|  | // CHECK:         %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor<i32> into tensor<1xi32> | ||||||
|  | // CHECK:         %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<1xi32> to tensor<1xui32> | ||||||
|  | // CHECK:         return %[[RET_UNSIGNED]] : tensor<1xui32> | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
| // CHECK-LABEL: func @reshape_1D_0D | // CHECK-LABEL: func @reshape_1D_0D | ||||||
| func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> { | func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> { | ||||||
|   %0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32> |   %0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32> | ||||||
|  | @ -491,6 +504,19 @@ func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> { | ||||||
| 
 | 
 | ||||||
| // ----- | // ----- | ||||||
| 
 | 
 | ||||||
|  | func @reshape_1D_0D_unsigned(%arg0: tensor<1xui32>) -> tensor<ui32> { | ||||||
|  |   %0 = "mhlo.reshape"(%arg0) : (tensor<1xui32>) -> tensor<ui32> | ||||||
|  |   return %0 : tensor<ui32> | ||||||
|  | } | ||||||
|  | // CHECK-LABEL: func @reshape_1D_0D_unsigned | ||||||
|  | // CHECK-SAME:    %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]] | ||||||
|  | // CHECK:         %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<1xui32> to tensor<1xi32> | ||||||
|  | // CHECK:         %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor<1xi32> into tensor<i32> | ||||||
|  | // CHECK:         %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<i32> to tensor<ui32> | ||||||
|  | // CHECK:         return %[[RET_UNSIGNED]] : tensor<ui32> | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
| // CHECK-LABEL: func @reshape_3D_2D | // CHECK-LABEL: func @reshape_3D_2D | ||||||
| func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { | func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { | ||||||
|   %0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32> |   %0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32> | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue