Support dynamic-shaped operand in verification of BroadcastInDim.

Verification of HLO_BroadcastInDimOp was previously failing or crashing if the
operand had a dynamic shape or was unranked. Update the verification code to
allow the operand to be unranked or have dynamic shape.

PiperOrigin-RevId: 358056793
This commit is contained in:
Richard Uhler 2021-02-17 16:16:32 -08:00 committed by TensorFlow MLIR Team
parent dd237d4267
commit b579bd5d9e
2 changed files with 39 additions and 7 deletions

View File

@ -740,6 +740,12 @@ static LogicalResult Verify(BroadcastOp op) {
static LogicalResult Verify(BroadcastInDimOp op) {
auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
if (!operandType) {
// The following verification checks all depend on knowing the rank of
// the operand. Bail out now if we don't know the rank of the operand.
return success();
}
auto operandRank = operandType.getRank();
if (!op.broadcast_dimensions()) {
if (operandRank == 0) {
@ -783,6 +789,7 @@ static LogicalResult Verify(BroadcastInDimOp op) {
dimIndex, resultRank));
}
if (!operandType.isDynamicDim(i)) {
auto dimSize = operandType.getDimSize(i);
auto resultDimSize = resultType.getDimSize(dimIndex);
if (dimSize != 1 && dimSize != resultDimSize) {
@ -792,6 +799,7 @@ static LogicalResult Verify(BroadcastInDimOp op) {
i, dimSize, dimIndex, resultDimSize));
}
}
}
return success();
}

View File

@ -180,6 +180,30 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x
// -----
// Regression test for b/180052624, where this was improperly marked as an
// invalid mhlo.broadcast_in_dim op.
// CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand
func @broadcast_in_dim_dynamic_shaped_operand(%arg0 : tensor<?xf32>) -> tensor<2xf32> {
%0 = "mhlo.broadcast_in_dim"(%arg0) {
broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<?xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// Regression test for b/180052624, where this crashed verification given the
// unranked operand.
// CHECK-LABEL: func @broadcast_in_dim_unranked_operand
func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32> {
%0 = "mhlo.broadcast_in_dim"(%arg0) {
broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<*xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
// expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}}
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {