From 27968619b73897e0f79a7629d353f493b828e5b1 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Fri, 16 Oct 2020 19:38:35 -0700 Subject: [PATCH] Verify non-scalar inputs for HLO concat XLA HLO concat does not accept scalars, so fail verification if this occurs. Avoids segfault when accessing an empty output shape. PiperOrigin-RevId: 337618167 --- lib/Dialect/mhlo/IR/hlo_ops.cc | 3 +++ tests/ops.mlir | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index f8a92cc..241b593 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1053,6 +1053,9 @@ LogicalResult ConcatenateOp::inferReturnTypes( return success(); } + if (first_type.getRank() == 0) + return emitOptionalError(location, "rank-0 values cannot be concatenated"); + auto out_shape = llvm::to_vector<6>(first_type.getShape()); // Determine what the non-concatenate dimensions should be. diff --git a/tests/ops.mlir b/tests/ops.mlir index 4462d9c..fb4ab62 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -328,6 +328,14 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor< // ----- +func @concat_0D(%arg0: tensor, %arg1: tensor) -> tensor<2xi32> { + // expected-error@+1 {{rank-0 values cannot be concatenated}} + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor, tensor) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +// ----- + // CHECK-LABEL: @concat_1D func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>