[MHLO:Linalg] Add support for lowering reshape of unsigned tensors
PiperOrigin-RevId: 373461627
This commit is contained in:
parent
a2c9b3c9d7
commit
d764806c1e
|
@ -782,6 +782,9 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
||||||
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
|
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
result_type = this->typeConverter->convertType(result_type)
|
||||||
|
.template cast<ShapedType>();
|
||||||
|
|
||||||
// Compute the reassociation maps for the linalg operation.
|
// Compute the reassociation maps for the linalg operation.
|
||||||
ArrayRef<int64_t> src_shape =
|
ArrayRef<int64_t> src_shape =
|
||||||
(operand_type.getRank() > result_type.getRank()
|
(operand_type.getRank() > result_type.getRank()
|
||||||
|
|
|
@ -482,6 +482,19 @@ func @reshape_0D_1D(%arg0: tensor<i32>) -> tensor<1xi32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @reshape_0D_1D_unsigned(%arg0: tensor<ui32>) -> tensor<1xui32> {
|
||||||
|
%0 = "mhlo.reshape"(%arg0) : (tensor<ui32>) -> tensor<1xui32>
|
||||||
|
return %0 : tensor<1xui32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @reshape_0D_1D_unsigned
|
||||||
|
// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK: %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<ui32> to tensor<i32>
|
||||||
|
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor<i32> into tensor<1xi32>
|
||||||
|
// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<1xi32> to tensor<1xui32>
|
||||||
|
// CHECK: return %[[RET_UNSIGNED]] : tensor<1xui32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @reshape_1D_0D
|
// CHECK-LABEL: func @reshape_1D_0D
|
||||||
func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> {
|
func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> {
|
||||||
%0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32>
|
%0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32>
|
||||||
|
@ -491,6 +504,19 @@ func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @reshape_1D_0D_unsigned(%arg0: tensor<1xui32>) -> tensor<ui32> {
|
||||||
|
%0 = "mhlo.reshape"(%arg0) : (tensor<1xui32>) -> tensor<ui32>
|
||||||
|
return %0 : tensor<ui32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @reshape_1D_0D_unsigned
|
||||||
|
// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK: %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<1xui32> to tensor<1xi32>
|
||||||
|
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor<1xi32> into tensor<i32>
|
||||||
|
// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<i32> to tensor<ui32>
|
||||||
|
// CHECK: return %[[RET_UNSIGNED]] : tensor<ui32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @reshape_3D_2D
|
// CHECK-LABEL: func @reshape_3D_2D
|
||||||
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
|
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
|
||||||
%0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
|
%0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
|
||||||
|
|
Loading…
Reference in New Issue