From d764806c1ef44996f049bd074b3f4765af48708b Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Wed, 12 May 2021 15:13:20 -0700 Subject: [PATCH] [MHLO:Linalg] Add support for lowering reshape of unsigned tensors PiperOrigin-RevId: 373461627 --- .../mhlo/transforms/legalize_to_linalg.cc | 3 +++ tests/hlo-legalize-to-linalg.mlir | 26 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 2dedf0e..d74994f 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -782,6 +782,9 @@ class ReshapeOpConverter : public OpConversionPattern { if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) return failure(); + result_type = this->typeConverter->convertType(result_type) + .template cast(); + // Compute the reassociation maps for the linalg operation. ArrayRef src_shape = (operand_type.getRank() > result_type.getRank() diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index f8de6f2..acd425d 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -482,6 +482,19 @@ func @reshape_0D_1D(%arg0: tensor) -> tensor<1xi32> { // ----- +func @reshape_0D_1D_unsigned(%arg0: tensor) -> tensor<1xui32> { + %0 = "mhlo.reshape"(%arg0) : (tensor) -> tensor<1xui32> + return %0 : tensor<1xui32> +} +// CHECK-LABEL: func @reshape_0D_1D_unsigned +// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]] +// CHECK: %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor to tensor +// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor into tensor<1xi32> +// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<1xi32> to tensor<1xui32> +// CHECK: return %[[RET_UNSIGNED]] : tensor<1xui32> + +// ----- + // CHECK-LABEL: func @reshape_1D_0D func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor { %0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor @@ -491,6 +504,19 @@ func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor { // ----- +func @reshape_1D_0D_unsigned(%arg0: tensor<1xui32>) -> tensor { + %0 = "mhlo.reshape"(%arg0) : (tensor<1xui32>) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @reshape_1D_0D_unsigned +// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]] +// CHECK: %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<1xui32> to tensor<1xi32> +// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor<1xi32> into tensor +// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor to tensor +// CHECK: return %[[RET_UNSIGNED]] : tensor + +// ----- + // CHECK-LABEL: func @reshape_3D_2D func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { %0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>