Verify non-scalar inputs for HLO concat
XLA HLO concat does not accept scalars, so fail verification if this occurs. Avoids segfault when accessing an empty output shape. PiperOrigin-RevId: 337618167
This commit is contained in:
parent
706718b4fb
commit
27968619b7
|
@ -1053,6 +1053,9 @@ LogicalResult ConcatenateOp::inferReturnTypes(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (first_type.getRank() == 0)
|
||||||
|
return emitOptionalError(location, "rank-0 values cannot be concatenated");
|
||||||
|
|
||||||
auto out_shape = llvm::to_vector<6>(first_type.getShape());
|
auto out_shape = llvm::to_vector<6>(first_type.getShape());
|
||||||
|
|
||||||
// Determine what the non-concatenate dimensions should be.
|
// Determine what the non-concatenate dimensions should be.
|
||||||
|
|
|
@ -328,6 +328,14 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @concat_0D(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<2xi32> {
|
||||||
|
// expected-error@+1 {{rank-0 values cannot be concatenated}}
|
||||||
|
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
|
||||||
|
return %0 : tensor<2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @concat_1D
|
// CHECK-LABEL: @concat_1D
|
||||||
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
|
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
|
||||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
|
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
|
||||||
|
|
Loading…
Reference in New Issue