Update mhlo.concatenate inferred return type for dynamics

MHLO concatenate should support dynamic inputs. Its possible that the output
shape can be inferred from a dimension in one input that is not dynamic in
another.

PiperOrigin-RevId: 331054181
This commit is contained in:
Robert Suderman 2020-09-10 17:45:16 -07:00 committed by TensorFlow MLIR Team
parent 6eefb07767
commit d0c8d17373
1 changed files with 26 additions and 1 deletions

View File

@ -1005,13 +1005,38 @@ LogicalResult ConcatenateOp::inferReturnTypes(
} }
} }
// If an input is unranked the output shape is unranked. // Find the first ranked input to determine the output rank.
for (auto type : operands.getTypes()) {
auto shaped_type = type.cast<ShapedType>();
if (shaped_type.hasRank()) {
first_type = shaped_type;
break;
}
}
// If all inputs are unranked, the result must be unranked.
if (!first_type.hasRank()) { if (!first_type.hasRank()) {
inferredReturnTypes.push_back(UnrankedTensorType::get(out_element)); inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
return success(); return success();
} }
auto out_shape = llvm::to_vector<6>(first_type.getShape()); auto out_shape = llvm::to_vector<6>(first_type.getShape());
// Determine what the non-concatenate dimensions should be.
for (auto type : operands.getTypes()) {
auto shaped_ty = type.cast<ShapedType>();
if (!shaped_ty.hasRank()) {
continue;
}
for (auto it : llvm::enumerate(shaped_ty.getShape())) {
// If a dimension is not dynamic, the output shape should match.
if (ShapedType::isDynamic(out_shape[it.index()])) {
out_shape[it.index()] = it.value();
}
}
}
out_shape[dimension] = 0; out_shape[dimension] = 0;
for (auto operand : operands.getTypes()) { for (auto operand : operands.getTypes()) {