From 329b1fd07163cf9a18451207d27e2290bb07dcd4 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 1 Mar 2021 07:59:26 -0800 Subject: [PATCH] Verify compatible shapes in unpack verification rather than exact Previously this would be too strict and fail if dynamic and static dims were compared. Dynamic/unknown are treated as "maybe equal" to a static value without further info, so at this layer don't flag as invalid unless truly are. PiperOrigin-RevId: 360189086 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h | 1 + include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 5 +++++ tests/ops.mlir | 18 ++++++++++-------- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index f1763c3..21e9c9f 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index cca165e..519873d 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -888,6 +888,11 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate", let hasCanonicalizer = 1; let hasFolder = 1; + let extraClassDeclaration = [{ + static bool isCompatibleReturnTypes(ArrayRef l, ArrayRef r) { + return succeeded(mlir::verifyCompatibleShapes(l, r)); + } + }]; } def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", diff --git a/tests/ops.mlir b/tests/ops.mlir index 93c4a76..358b760 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -368,6 +368,16 @@ func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { // ----- +// CHECK-LABEL: @concat_1D +// Verifies that an error is not thrown if the inferred type is compatible with +// the result type. +func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> { + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +} + +// ----- + func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> { // expected-error@+1 {{'mhlo.concatenate' op requires the same element type for all operands and results}} %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32> @@ -384,14 +394,6 @@ func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor< // ----- -func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> { - // expected-error@+1 {{op inferred type(s) 'tensor<*xi32>' are incompatible with return type(s) of operation 'tensor<3xi32>'}} - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32> - return %0 : tensor<3xi32> -} - -// ----- - func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> { // expected-error@+1 {{op inferred type(s) 'tensor<3xi32>' are incompatible with return type(s) of operation 'tensor<4xi32>'}} %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>