Verify compatible shapes in unpack verification rather than exact

Previously this would be too strict and fail if dynamic and static dims were
compared. Dynamic/unknown are treated as "maybe equal" to a static value without further info, so at this layer don't flag as invalid unless truly are.

PiperOrigin-RevId: 360189086
This commit is contained in:
Jacques Pienaar 2021-03-01 07:59:26 -08:00 committed by TensorFlow MLIR Team
parent 70ee9369d5
commit 329b1fd071
3 changed files with 16 additions and 8 deletions

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -888,6 +888,11 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate",
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
let hasFolder = 1; let hasFolder = 1;
let extraClassDeclaration = [{
static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r) {
return succeeded(mlir::verifyCompatibleShapes(l, r));
}
}];
} }
def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",

View File

@ -368,6 +368,16 @@ func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
// ----- // -----
// CHECK-LABEL: @concat_1D
// Verifies that an error is not thrown if the inferred type is compatible with
// the result type.
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> { func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> {
// expected-error@+1 {{'mhlo.concatenate' op requires the same element type for all operands and results}} // expected-error@+1 {{'mhlo.concatenate' op requires the same element type for all operands and results}}
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32> %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32>
@ -384,14 +394,6 @@ func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<
// ----- // -----
func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
// expected-error@+1 {{op inferred type(s) 'tensor<*xi32>' are incompatible with return type(s) of operation 'tensor<3xi32>'}}
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> { func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> {
// expected-error@+1 {{op inferred type(s) 'tensor<3xi32>' are incompatible with return type(s) of operation 'tensor<4xi32>'}} // expected-error@+1 {{op inferred type(s) 'tensor<3xi32>' are incompatible with return type(s) of operation 'tensor<4xi32>'}}
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32> %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>