[MHLO:Linalg] Add support for lowering reshape of unsigned tensors

PiperOrigin-RevId: 373461627
This commit is contained in:
Hanhan Wang 2021-05-12 15:13:20 -07:00 committed by TensorFlow MLIR Team
parent a2c9b3c9d7
commit d764806c1e
2 changed files with 29 additions and 0 deletions

View File

@ -782,6 +782,9 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
return failure(); return failure();
result_type = this->typeConverter->convertType(result_type)
.template cast<ShapedType>();
// Compute the reassociation maps for the linalg operation. // Compute the reassociation maps for the linalg operation.
ArrayRef<int64_t> src_shape = ArrayRef<int64_t> src_shape =
(operand_type.getRank() > result_type.getRank() (operand_type.getRank() > result_type.getRank()

View File

@ -482,6 +482,19 @@ func @reshape_0D_1D(%arg0: tensor<i32>) -> tensor<1xi32> {
// ----- // -----
func @reshape_0D_1D_unsigned(%arg0: tensor<ui32>) -> tensor<1xui32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<ui32>) -> tensor<1xui32>
return %0 : tensor<1xui32>
}
// CHECK-LABEL: func @reshape_0D_1D_unsigned
// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]]
// CHECK: %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<ui32> to tensor<i32>
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor<i32> into tensor<1xi32>
// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<1xi32> to tensor<1xui32>
// CHECK: return %[[RET_UNSIGNED]] : tensor<1xui32>
// -----
// CHECK-LABEL: func @reshape_1D_0D // CHECK-LABEL: func @reshape_1D_0D
func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> { func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32> %0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32>
@ -491,6 +504,19 @@ func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> {
// ----- // -----
func @reshape_1D_0D_unsigned(%arg0: tensor<1xui32>) -> tensor<ui32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<1xui32>) -> tensor<ui32>
return %0 : tensor<ui32>
}
// CHECK-LABEL: func @reshape_1D_0D_unsigned
// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]]
// CHECK: %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<1xui32> to tensor<1xi32>
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor<1xi32> into tensor<i32>
// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<i32> to tensor<ui32>
// CHECK: return %[[RET_UNSIGNED]] : tensor<ui32>
// -----
// CHECK-LABEL: func @reshape_3D_2D // CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32> %0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>