Support dynamic-shaped operand in verification of BroadcastInDim.
Verification of HLO_BroadcastInDimOp was previously failing or crashing if the operand had a dynamic shape or was unranked. Update the verification code to allow the operand to be unranked or have dynamic shape. PiperOrigin-RevId: 358056793
This commit is contained in:
		
							parent
							
								
									dd237d4267
								
							
						
					
					
						commit
						b579bd5d9e
					
				| 
						 | 
					@ -740,6 +740,12 @@ static LogicalResult Verify(BroadcastOp op) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static LogicalResult Verify(BroadcastInDimOp op) {
 | 
					static LogicalResult Verify(BroadcastInDimOp op) {
 | 
				
			||||||
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
 | 
					  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
 | 
				
			||||||
 | 
					  if (!operandType) {
 | 
				
			||||||
 | 
					    // The following verification checks all depend on knowing the rank of
 | 
				
			||||||
 | 
					    // the operand. Bail out now if we don't know the rank of the operand.
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto operandRank = operandType.getRank();
 | 
					  auto operandRank = operandType.getRank();
 | 
				
			||||||
  if (!op.broadcast_dimensions()) {
 | 
					  if (!op.broadcast_dimensions()) {
 | 
				
			||||||
    if (operandRank == 0) {
 | 
					    if (operandRank == 0) {
 | 
				
			||||||
| 
						 | 
					@ -783,6 +789,7 @@ static LogicalResult Verify(BroadcastInDimOp op) {
 | 
				
			||||||
                        dimIndex, resultRank));
 | 
					                        dimIndex, resultRank));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (!operandType.isDynamicDim(i)) {
 | 
				
			||||||
      auto dimSize = operandType.getDimSize(i);
 | 
					      auto dimSize = operandType.getDimSize(i);
 | 
				
			||||||
      auto resultDimSize = resultType.getDimSize(dimIndex);
 | 
					      auto resultDimSize = resultType.getDimSize(dimIndex);
 | 
				
			||||||
      if (dimSize != 1 && dimSize != resultDimSize) {
 | 
					      if (dimSize != 1 && dimSize != resultDimSize) {
 | 
				
			||||||
| 
						 | 
					@ -792,6 +799,7 @@ static LogicalResult Verify(BroadcastInDimOp op) {
 | 
				
			||||||
                          i, dimSize, dimIndex, resultDimSize));
 | 
					                          i, dimSize, dimIndex, resultDimSize));
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return success();
 | 
					  return success();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -180,6 +180,30 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Regression test for b/180052624, where this was improperly marked as an
 | 
				
			||||||
 | 
					// invalid mhlo.broadcast_in_dim op.
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand
 | 
				
			||||||
 | 
					func @broadcast_in_dim_dynamic_shaped_operand(%arg0 : tensor<?xf32>) -> tensor<2xf32> {
 | 
				
			||||||
 | 
					  %0 = "mhlo.broadcast_in_dim"(%arg0) {
 | 
				
			||||||
 | 
					    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
				
			||||||
 | 
					  } : (tensor<?xf32>) -> tensor<2xf32>
 | 
				
			||||||
 | 
					  return %0 : tensor<2xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Regression test for b/180052624, where this crashed verification given the
 | 
				
			||||||
 | 
					// unranked operand.
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @broadcast_in_dim_unranked_operand
 | 
				
			||||||
 | 
					func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32> {
 | 
				
			||||||
 | 
					  %0 = "mhlo.broadcast_in_dim"(%arg0) {
 | 
				
			||||||
 | 
					    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
				
			||||||
 | 
					  } : (tensor<*xf32>) -> tensor<2xf32>
 | 
				
			||||||
 | 
					  return %0 : tensor<2xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
 | 
					func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
 | 
				
			||||||
  // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}}
 | 
					  // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}}
 | 
				
			||||||
  %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
 | 
					  %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue