diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index f91c9a1..14ac3da 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1005,13 +1005,38 @@ LogicalResult ConcatenateOp::inferReturnTypes(
     }
   }
 
-  // If an input is unranked the output shape is unranked.
+  // Find the first ranked input to determine the output rank.
+  for (auto type : operands.getTypes()) {
+    auto shaped_type = type.cast<ShapedType>();
+    if (shaped_type.hasRank()) {
+      first_type = shaped_type;
+      break;
+    }
+  }
+
+  // If all inputs are unranked, the result must be unranked.
   if (!first_type.hasRank()) {
     inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
     return success();
   }
 
   auto out_shape = llvm::to_vector<6>(first_type.getShape());
+
+  // Determine what the non-concatenate dimensions should be.
+  for (auto type : operands.getTypes()) {
+    auto shaped_ty = type.cast<ShapedType>();
+    if (!shaped_ty.hasRank()) {
+      continue;
+    }
+
+    for (auto it : llvm::enumerate(shaped_ty.getShape())) {
+      // If a dimension is not dynamic, the output shape should match.
+      if (ShapedType::isDynamic(out_shape[it.index()])) {
+        out_shape[it.index()] = it.value();
+      }
+    }
+  }
+
   out_shape[dimension] = 0;
 
   for (auto operand : operands.getTypes()) {
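
Not part of the patch: a small illustrative sketch of the effect this hunk aims for, assuming standard mhlo.concatenate syntax; the IR below is an example I constructed, not taken from the change itself.

    // Illustrative only: with this inference, non-concatenate dimensions are
    // filled in from any ranked operand, so a static size in a later operand
    // can refine a dynamic size in the first one.
    //   %r = "mhlo.concatenate"(%a, %b) {dimension = 0 : i64}
    //        : (tensor<?x?xf32>, tensor<5x3xf32>) -> tensor<?x3xf32>
    // Before the change, only the first operand's shape was consulted for the
    // non-concatenate dimensions, so dimension 1 of the result stayed dynamic.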