Fix importing variadic output (#152)

* Fix importing variadic output

* clang-format
This commit is contained in:
Tung D. Le 2020-05-27 22:33:53 +09:00 committed by GitHub
parent 4f8fd9d1bf
commit 2b6befce87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 7 deletions

View File

@ -268,9 +268,13 @@ private:
if (node.output()[i].empty()) { if (node.output()[i].empty()) {
outputTypes.emplace_back(builder_.getNoneType()); outputTypes.emplace_back(builder_.getNoneType());
} else { } else {
if (i < outputMap.size() && outputMap[i] >= MAX_TYPE) { auto j = i;
// Variadic output is a single ODS result.
if (variadicOut)
j = 0;
if (j < outputMap.size() && outputMap[j] >= MAX_TYPE) {
// Mapping gives a connection with an input. // Mapping gives a connection with an input.
mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType(); mlir::Type inputType = inputs[outputMap[j] - MAX_TYPE].getType();
if (inputType.isa<mlir::TensorType>()) { if (inputType.isa<mlir::TensorType>()) {
auto elementType = auto elementType =
inputType.cast<mlir::TensorType>().getElementType(); inputType.cast<mlir::TensorType>().getElementType();
@ -279,9 +283,9 @@ private:
} else { } else {
outputTypes.push_back(inputType); outputTypes.push_back(inputType);
} }
} else if (i < outputMap.size() && outputMap[i] != -1) { } else if (j < outputMap.size() && outputMap[j] != -1) {
// Mapping gives a direct type. // Mapping gives a direct type.
auto elementType = buildTypeFromIndex(outputMap[i]); auto elementType = buildTypeFromIndex(outputMap[j]);
auto outType = mlir::UnrankedTensorType::get(elementType); auto outType = mlir::UnrankedTensorType::get(elementType);
outputTypes.emplace_back(outType); outputTypes.emplace_back(outType);
} else { } else {
@ -305,11 +309,18 @@ private:
op.getOperation())) { op.getOperation())) {
auto outTypes = opWithTypeInference.resultTypeInference(); auto outTypes = opWithTypeInference.resultTypeInference();
for (int i = 0; i < node.output().size(); i++) { for (int i = 0; i < node.output().size(); i++) {
if (variadicOut)
(*(op.getODSResults(0).begin() + i)).setType(outTypes[i]);
else
(*op.getODSResults(i).begin()).setType(outTypes[i]); (*op.getODSResults(i).begin()).setType(outTypes[i]);
} }
} }
for (int i = 0; i < node.output().size(); i++) { for (int i = 0; i < node.output().size(); i++) {
if (variadicOut)
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
*(op.getODSResults(0).begin() + i));
else
frontend_symbols_.AddMapping( frontend_symbols_.AddMapping(
legalize_name(node.output()[i]), *(op.getODSResults(i).begin())); legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
} }