Verify non-scalar inputs for HLO concat

XLA HLO concat does not accept scalars, so fail verification if this occurs. Avoids segfault when accessing an empty output shape.

PiperOrigin-RevId: 337618167
This commit is contained in:
Jacques Pienaar 2020-10-16 19:38:35 -07:00 committed by TensorFlow MLIR Team
parent 706718b4fb
commit 27968619b7
2 changed files with 11 additions and 0 deletions

View File

@ -1053,6 +1053,9 @@ LogicalResult ConcatenateOp::inferReturnTypes(
return success(); return success();
} }
if (first_type.getRank() == 0)
return emitOptionalError(location, "rank-0 values cannot be concatenated");
auto out_shape = llvm::to_vector<6>(first_type.getShape()); auto out_shape = llvm::to_vector<6>(first_type.getShape());
// Determine what the non-concatenate dimensions should be. // Determine what the non-concatenate dimensions should be.

View File

@ -328,6 +328,14 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<
// ----- // -----
func @concat_0D(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<2xi32> {
// expected-error@+1 {{rank-0 values cannot be concatenated}}
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
// -----
// CHECK-LABEL: @concat_1D // CHECK-LABEL: @concat_1D
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>