Update mhlo.concatenate inferred return type for dynamics
MHLO concatenate should support dynamic inputs. Its possible that the output shape can be inferred from a dimension in one input that is not dynamic in another. PiperOrigin-RevId: 331054181
This commit is contained in:
parent
6eefb07767
commit
d0c8d17373
|
@ -1005,13 +1005,38 @@ LogicalResult ConcatenateOp::inferReturnTypes(
|
|||
}
|
||||
}
|
||||
|
||||
// If an input is unranked the output shape is unranked.
|
||||
// Find the first ranked input to determine the output rank.
|
||||
for (auto type : operands.getTypes()) {
|
||||
auto shaped_type = type.cast<ShapedType>();
|
||||
if (shaped_type.hasRank()) {
|
||||
first_type = shaped_type;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If all inputs are unranked, the result must be unranked.
|
||||
if (!first_type.hasRank()) {
|
||||
inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
|
||||
return success();
|
||||
}
|
||||
|
||||
auto out_shape = llvm::to_vector<6>(first_type.getShape());
|
||||
|
||||
// Determine what the non-concatenate dimensions should be.
|
||||
for (auto type : operands.getTypes()) {
|
||||
auto shaped_ty = type.cast<ShapedType>();
|
||||
if (!shaped_ty.hasRank()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto it : llvm::enumerate(shaped_ty.getShape())) {
|
||||
// If a dimension is not dynamic, the output shape should match.
|
||||
if (ShapedType::isDynamic(out_shape[it.index()])) {
|
||||
out_shape[it.index()] = it.value();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out_shape[dimension] = 0;
|
||||
|
||||
for (auto operand : operands.getTypes()) {
|
||||
|
|
Loading…
Reference in New Issue