Support dynamic-shaped operand in verification of BroadcastInDim.
Verification of HLO_BroadcastInDimOp was previously failing or crashing if the operand had a dynamic shape or was unranked. Update the verification code to allow the operand to be unranked or to have a dynamic shape.

PiperOrigin-RevId: 358056793
commit b579bd5d9e
parent dd237d4267
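For illustration, this is the kind of IR the verifier previously rejected (or crashed on): a mhlo.broadcast_in_dim whose operand has a dynamic shape. The snippet mirrors the regression test added in this change; with the fix, it verifies successfully. The same applies to an unranked operand (tensor<*xf32>), covered by the second new test below.

  func @broadcast_in_dim_dynamic_shaped_operand(%arg0 : tensor<?xf32>) -> tensor<2xf32> {
    %0 = "mhlo.broadcast_in_dim"(%arg0) {
      broadcast_dimensions = dense<0> : tensor<1xi64>
    } : (tensor<?xf32>) -> tensor<2xf32>
    return %0 : tensor<2xf32>
  }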
Changes to the BroadcastInDimOp verifier:

@@ -740,6 +740,12 @@ static LogicalResult Verify(BroadcastOp op) {
 
 static LogicalResult Verify(BroadcastInDimOp op) {
   auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
+  if (!operandType) {
+    // The following verification checks all depend on knowing the rank of
+    // the operand. Bail out now if we don't know the rank of the operand.
+    return success();
+  }
+
   auto operandRank = operandType.getRank();
   if (!op.broadcast_dimensions()) {
     if (operandRank == 0) {

@@ -783,6 +789,7 @@ static LogicalResult Verify(BroadcastInDimOp op) {
                         dimIndex, resultRank));
     }
 
+    if (!operandType.isDynamicDim(i)) {
     auto dimSize = operandType.getDimSize(i);
     auto resultDimSize = resultType.getDimSize(dimIndex);
     if (dimSize != 1 && dimSize != resultDimSize) {

@@ -792,6 +799,7 @@ static LogicalResult Verify(BroadcastInDimOp op) {
                         i, dimSize, dimIndex, resultDimSize));
     }
+    }
   }
 
   return success();
 }
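Note that the new isDynamicDim guard only skips the size comparison for dynamic operand dimensions; static mismatches are still rejected. As an illustrative (hypothetical) example of an op the verifier continues to flag, with the diagnostic paraphrased in the comment:

  // Still invalid: operand dimension 0 has static size 3, which is neither 1
  // nor equal to the result dimension size 2.
  func @broadcast_in_dim_static_mismatch(%arg0 : tensor<3xf32>) -> tensor<2xf32> {
    %0 = "mhlo.broadcast_in_dim"(%arg0) {
      broadcast_dimensions = dense<0> : tensor<1xi64>
    } : (tensor<3xf32>) -> tensor<2xf32>
    return %0 : tensor<2xf32>
  }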
Changes to the verification tests:

@@ -180,6 +180,30 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x
 
 // -----
 
+// Regression test for b/180052624, where this was improperly marked as an
+// invalid mhlo.broadcast_in_dim op.
+// CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand
+func @broadcast_in_dim_dynamic_shaped_operand(%arg0 : tensor<?xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.broadcast_in_dim"(%arg0) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<?xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+// Regression test for b/180052624, where this crashed verification given the
+// unranked operand.
+// CHECK-LABEL: func @broadcast_in_dim_unranked_operand
+func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.broadcast_in_dim"(%arg0) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<*xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
 func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
   // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}}
   %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {