Verify compatible shapes in unpack verification rather than exact
Previously this check was too strict and failed when dynamic and static dimensions were compared. Without further information, a dynamic/unknown dimension is treated as "maybe equal" to a static value, so at this layer shapes are only flagged as invalid when they are truly incompatible.

PiperOrigin-RevId: 360189086
Parent: 70ee9369d5
Commit: 329b1fd071
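To make the "maybe equal" rule concrete, here is a minimal sketch (illustrative only, not part of this commit) using MLIR's mlir::verifyCompatibleShape from TypeUtilities.h; the helper names exactMatch and compatibleMatch are made up for the example:

// Minimal sketch: exact type equality rejects a static/unranked pairing,
// while the compatible-shape check treats unknown dimensions and ranks as
// "maybe equal".
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"

bool exactMatch(mlir::Type a, mlir::Type b) { return a == b; }

bool compatibleMatch(mlir::Type a, mlir::Type b) {
  return mlir::succeeded(mlir::verifyCompatibleShape(a, b));
}

// With an MLIRContext ctx:
//   auto i32      = mlir::IntegerType::get(&ctx, 32);
//   auto staticTy = mlir::RankedTensorType::get({3}, i32);  // tensor<3xi32>
//   auto unranked = mlir::UnrankedTensorType::get(i32);     // tensor<*xi32>
//   exactMatch(staticTy, unranked)      -> false (too strict for this check)
//   compatibleMatch(staticTy, unranked) -> true  ("maybe equal")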
@@ -28,6 +28,7 @@ limitations under the License.
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -888,6 +888,11 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate",
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 
+  let extraClassDeclaration = [{
+    static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r) {
+      return succeeded(mlir::verifyCompatibleShapes(l, r));
+    }
+  }];
 }
 
 def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
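For orientation, a rough sketch of the kind of result-type check this hook relaxes, assuming the InferTypeOpInterface-generated verification compares inferred result types against the declared ones; checkInferredResults is a hypothetical helper, not the generated MHLO code, and its diagnostic text only mirrors the messages in the tests below:

// Hypothetical helper: accept declared result types whenever they are
// shape-compatible with the inferred ones, instead of requiring exact
// equality.
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"

mlir::LogicalResult checkInferredResults(mlir::Operation *op,
                                         mlir::TypeRange inferred) {
  // Exact equality would reject a declared tensor<3xi32> against an inferred
  // tensor<*xi32>; verifyCompatibleShapes treats them as "maybe equal".
  if (mlir::failed(mlir::verifyCompatibleShapes(inferred, op->getResultTypes())))
    return op->emitOpError(
        "inferred type(s) are incompatible with return type(s) of operation");
  return mlir::success();
}

With exact equality, a check like this is what previously rejected an inferred tensor<*xi32> against a declared tensor<3xi32>, which is why the @concat_1D_unranked_error test is removed below while a passing @concat_1D test is added.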
@@ -368,6 +368,16 @@ func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
 
+// -----
+
+// CHECK-LABEL: @concat_1D
+// Verifies that an error is not thrown if the inferred type is compatible with
+// the result type.
+func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
+  %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
+  return %0 : tensor<3xi32>
+}
+
 // -----
 
 func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> {
   // expected-error@+1 {{'mhlo.concatenate' op requires the same element type for all operands and results}}
   %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32>
@@ -384,14 +394,6 @@ func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<
 
-// -----
-
-func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
-  // expected-error@+1 {{op inferred type(s) 'tensor<*xi32>' are incompatible with return type(s) of operation 'tensor<3xi32>'}}
-  %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
-  return %0 : tensor<3xi32>
-}
-
 // -----
 
 func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> {
   // expected-error@+1 {{op inferred type(s) 'tensor<3xi32>' are incompatible with return type(s) of operation 'tensor<4xi32>'}}
   %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>