Relax DynamicBroadcastInDim verifier when dimensions are dynamic.
For input and output dimensions which must match, we shouldn't fail in the case where one dim is dynamic and the other is static. This is insufficient information to conclude a dimension mismatch. PiperOrigin-RevId: 325344738
This commit is contained in:
parent
c78e144f82
commit
cd22ecd136
|
@ -748,10 +748,12 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) {
|
|||
|
||||
auto dimSize = operandType.getDimSize(i);
|
||||
auto resultDimSize = resultType.getDimSize(dimIndex);
|
||||
if (dimSize != 1 && dimSize != resultDimSize) {
|
||||
// Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we
|
||||
// add a manual check for this.
|
||||
if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
|
||||
return op.emitOpError(
|
||||
llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
|
||||
"1 or size of result dimension {2} ({3})",
|
||||
llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
|
||||
"with size of result dimension {2} ({3})",
|
||||
i, dimSize, dimIndex, resultDimSize));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -116,6 +116,30 @@ func @dynamic_broadcast_in_dim(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) ->
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @dynamic_broadcast_in_dim_unknown_dim
|
||||
func @dynamic_broadcast_in_dim_unknown_dim(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor<?x?x?xf32> {
|
||||
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @dynamic_broadcast_in_dim_ok_dim
|
||||
func @dynamic_broadcast_in_dim_ok_dim(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> {
|
||||
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32>
|
||||
return %0 : tensor<7x8x9xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> {
|
||||
// expected-error@+1 {{size of operand dimension 0 (32) is not compatible with size of result dimension 2 (9)}}
|
||||
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor<7x8x9xf32>
|
||||
return %0 : tensor<7x8x9xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
|
||||
// expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}}
|
||||
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
|
||||
|
|
Loading…
Reference in New Issue