Verify non-scalar inputs for HLO concat
XLA HLO concat does not accept scalars, so fail verification if this occurs. Avoids segfault when accessing an empty output shape. PiperOrigin-RevId: 337618167
This commit is contained in:
parent
706718b4fb
commit
27968619b7
|
@ -1053,6 +1053,9 @@ LogicalResult ConcatenateOp::inferReturnTypes(
|
|||
return success();
|
||||
}
|
||||
|
||||
if (first_type.getRank() == 0)
|
||||
return emitOptionalError(location, "rank-0 values cannot be concatenated");
|
||||
|
||||
auto out_shape = llvm::to_vector<6>(first_type.getShape());
|
||||
|
||||
// Determine what the non-concatenate dimensions should be.
|
||||
|
|
|
@ -328,6 +328,14 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<
|
|||
|
||||
// -----
|
||||
|
||||
func @concat_0D(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<2xi32> {
|
||||
// expected-error@+1 {{rank-0 values cannot be concatenated}}
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @concat_1D
|
||||
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
|
||||
|
|
Loading…
Reference in New Issue