Fix importing variadic output (#152)

* Fix importing variadic output

* clang-format
This commit is contained in:
Tung D. Le 2020-05-27 22:33:53 +09:00 committed by GitHub
parent 4f8fd9d1bf
commit 2b6befce87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 7 deletions

View File

@ -268,9 +268,13 @@ private:
if (node.output()[i].empty()) {
outputTypes.emplace_back(builder_.getNoneType());
} else {
if (i < outputMap.size() && outputMap[i] >= MAX_TYPE) {
auto j = i;
// Variadic output is a single ODS result.
if (variadicOut)
j = 0;
if (j < outputMap.size() && outputMap[j] >= MAX_TYPE) {
// Mapping gives a connection with an input.
mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType();
mlir::Type inputType = inputs[outputMap[j] - MAX_TYPE].getType();
if (inputType.isa<mlir::TensorType>()) {
auto elementType =
inputType.cast<mlir::TensorType>().getElementType();
@ -279,9 +283,9 @@ private:
} else {
outputTypes.push_back(inputType);
}
} else if (i < outputMap.size() && outputMap[i] != -1) {
} else if (j < outputMap.size() && outputMap[j] != -1) {
// Mapping gives a direct type.
auto elementType = buildTypeFromIndex(outputMap[i]);
auto elementType = buildTypeFromIndex(outputMap[j]);
auto outType = mlir::UnrankedTensorType::get(elementType);
outputTypes.emplace_back(outType);
} else {
@ -305,13 +309,20 @@ private:
op.getOperation())) {
auto outTypes = opWithTypeInference.resultTypeInference();
for (int i = 0; i < node.output().size(); i++) {
(*op.getODSResults(i).begin()).setType(outTypes[i]);
if (variadicOut)
(*(op.getODSResults(0).begin() + i)).setType(outTypes[i]);
else
(*op.getODSResults(i).begin()).setType(outTypes[i]);
}
}
for (int i = 0; i < node.output().size(); i++) {
frontend_symbols_.AddMapping(
legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
if (variadicOut)
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
*(op.getODSResults(0).begin() + i));
else
frontend_symbols_.AddMapping(
legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
}
}