Handle empty tensors in SimplifyConcatSlice.
If the result of the slice is an empty tensor, do nothing. This fixes a crash: we can't create a `concat` with an empty operand range. PiperOrigin-RevId: 378354956
This commit is contained in:
		
							parent
							
								
									1770ed455f
								
							
						
					
					
						commit
						d828b457b3
					
				| 
						 | 
					@ -3030,6 +3030,13 @@ struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
 | 
				
			||||||
      return failure();
 | 
					      return failure();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // If there's nothing to slice that means the output is an empty tensor and
 | 
				
			||||||
 | 
					    // there is dead code. We do nothing here and rely on other passes to clean
 | 
				
			||||||
 | 
					    // this up.
 | 
				
			||||||
 | 
					    if (subset_size == 0) {
 | 
				
			||||||
 | 
					      return failure();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (subset_size > 1 && !concat.getResult().hasOneUse()) {
 | 
					    if (subset_size > 1 && !concat.getResult().hasOneUse()) {
 | 
				
			||||||
      return failure();
 | 
					      return failure();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -391,6 +391,16 @@ func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg
 | 
				
			||||||
  return %1 : tensor<2x5xf32>
 | 
					  return %1 : tensor<2x5xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: slice_concat_empty
 | 
				
			||||||
 | 
					func @slice_concat_empty(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> {
 | 
				
			||||||
 | 
					  %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
 | 
				
			||||||
 | 
					  %1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<0x5xf32>)
 | 
				
			||||||
 | 
					  %2 = "mhlo.concatenate"(%1, %arg2) { dimension = 0 : i64 } : (tensor<0x5xf32>, tensor<1x5xf32>) -> tensor<1x5xf32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: return %arg2
 | 
				
			||||||
 | 
					  return %2 : tensor<1x5xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: func @broadcast_in_dim_identity
 | 
					// CHECK-LABEL: func @broadcast_in_dim_identity
 | 
				
			||||||
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
 | 
					func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
 | 
				
			||||||
  // CHECK: return %arg0
 | 
					  // CHECK: return %arg0
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue