Verify compatible shapes in unpack verification rather than exact
Previously this would be too strict and fail if dynamic and static dims were compared. Dynamic/unknown are treated as "maybe equal" to a static value without further info, so at this layer don't flag as invalid unless truly are. PiperOrigin-RevId: 360189086
This commit is contained in:
parent
70ee9369d5
commit
329b1fd071
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
|
|
@ -888,6 +888,11 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate",
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r) {
|
||||||
|
return succeeded(mlir::verifyCompatibleShapes(l, r));
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
|
def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
|
||||||
|
|
|
@ -368,6 +368,16 @@ func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @concat_1D
|
||||||
|
// Verifies that an error is not thrown if the inferred type is compatible with
|
||||||
|
// the result type.
|
||||||
|
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
|
||||||
|
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
|
||||||
|
return %0 : tensor<3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> {
|
func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> {
|
||||||
// expected-error@+1 {{'mhlo.concatenate' op requires the same element type for all operands and results}}
|
// expected-error@+1 {{'mhlo.concatenate' op requires the same element type for all operands and results}}
|
||||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32>
|
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32>
|
||||||
|
@ -384,14 +394,6 @@ func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
|
|
||||||
// expected-error@+1 {{op inferred type(s) 'tensor<*xi32>' are incompatible with return type(s) of operation 'tensor<3xi32>'}}
|
|
||||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
|
|
||||||
return %0 : tensor<3xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> {
|
func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> {
|
||||||
// expected-error@+1 {{op inferred type(s) 'tensor<3xi32>' are incompatible with return type(s) of operation 'tensor<4xi32>'}}
|
// expected-error@+1 {{op inferred type(s) 'tensor<3xi32>' are incompatible with return type(s) of operation 'tensor<4xi32>'}}
|
||||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
|
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
|
||||||
|
|
Loading…
Reference in New Issue