Support dynamic-shaped operand in verification of BroadcastInDim.
Verification of HLO_BroadcastInDimOp was previously failing or crashing if the operand had a dynamic shape or was unranked. Update the verification code to allow the operand to be unranked or have dynamic shape. PiperOrigin-RevId: 358056793
This commit is contained in:
parent
dd237d4267
commit
b579bd5d9e
|
@ -740,6 +740,12 @@ static LogicalResult Verify(BroadcastOp op) {
|
|||
|
||||
static LogicalResult Verify(BroadcastInDimOp op) {
|
||||
auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
|
||||
if (!operandType) {
|
||||
// The following verification checks all depend on knowing the rank of
|
||||
// the operand. Bail out now if we don't know the rank of the operand.
|
||||
return success();
|
||||
}
|
||||
|
||||
auto operandRank = operandType.getRank();
|
||||
if (!op.broadcast_dimensions()) {
|
||||
if (operandRank == 0) {
|
||||
|
@ -783,6 +789,7 @@ static LogicalResult Verify(BroadcastInDimOp op) {
|
|||
dimIndex, resultRank));
|
||||
}
|
||||
|
||||
if (!operandType.isDynamicDim(i)) {
|
||||
auto dimSize = operandType.getDimSize(i);
|
||||
auto resultDimSize = resultType.getDimSize(dimIndex);
|
||||
if (dimSize != 1 && dimSize != resultDimSize) {
|
||||
|
@ -792,6 +799,7 @@ static LogicalResult Verify(BroadcastInDimOp op) {
|
|||
i, dimSize, dimIndex, resultDimSize));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -180,6 +180,30 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x
|
|||
|
||||
// -----
|
||||
|
||||
// Regression test for b/180052624, where this was improperly marked as an
|
||||
// invalid mhlo.broadcast_in_dim op.
|
||||
// CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand
|
||||
func @broadcast_in_dim_dynamic_shaped_operand(%arg0 : tensor<?xf32>) -> tensor<2xf32> {
|
||||
%0 = "mhlo.broadcast_in_dim"(%arg0) {
|
||||
broadcast_dimensions = dense<0> : tensor<1xi64>
|
||||
} : (tensor<?xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Regression test for b/180052624, where this crashed verification given the
|
||||
// unranked operand.
|
||||
// CHECK-LABEL: func @broadcast_in_dim_unranked_operand
|
||||
func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32> {
|
||||
%0 = "mhlo.broadcast_in_dim"(%arg0) {
|
||||
broadcast_dimensions = dense<0> : tensor<1xi64>
|
||||
} : (tensor<*xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
|
||||
// expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}}
|
||||
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
|
||||
|
|
Loading…
Reference in New Issue