Relax DynamicBroadcastInDim verifier when dimensions are dynamic.

For input and output dimensions which must match, we shouldn't fail in the case where one dim is dynamic and the other is static. This is insufficient information to conclude a dimension mismatch.

PiperOrigin-RevId: 325344738
This commit is contained in:
Lucy Fox 2020-08-06 17:29:29 -07:00 committed by Geoffrey Martin-Noble
parent c78e144f82
commit cd22ecd136
2 changed files with 29 additions and 3 deletions

View File

@ -748,10 +748,12 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) {
auto dimSize = operandType.getDimSize(i); auto dimSize = operandType.getDimSize(i);
auto resultDimSize = resultType.getDimSize(dimIndex); auto resultDimSize = resultType.getDimSize(dimIndex);
if (dimSize != 1 && dimSize != resultDimSize) { // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we
// add a manual check for this.
if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
return op.emitOpError( return op.emitOpError(
llvm::formatv("size of operand dimension {0} ({1}) is not equal to " llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
"1 or size of result dimension {2} ({3})", "with size of result dimension {2} ({3})",
i, dimSize, dimIndex, resultDimSize)); i, dimSize, dimIndex, resultDimSize));
} }
} }

View File

@ -116,6 +116,30 @@ func @dynamic_broadcast_in_dim(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) ->
// ----- // -----
// CHECK-LABEL: func @dynamic_broadcast_in_dim_unknown_dim
func @dynamic_broadcast_in_dim_unknown_dim(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor<?x?x?xf32> {
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// -----
// CHECK-LABEL: func @dynamic_broadcast_in_dim_ok_dim
func @dynamic_broadcast_in_dim_ok_dim(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> {
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32>
return %0 : tensor<7x8x9xf32>
}
// -----
func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> {
// expected-error@+1 {{size of operand dimension 0 (32) is not compatible with size of result dimension 2 (9)}}
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor<7x8x9xf32>
return %0 : tensor<7x8x9xf32>
}
// -----
func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
// expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}} // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}}
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>