From 2b6befce874ae0271062407a0b8079ef8b49375f Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Wed, 27 May 2020 22:33:53 +0900 Subject: [PATCH] Fix importing variadic output (#152) * Fix importing variadic output * clang-format --- src/Builder/FrontendDialectTransformer.cpp | 25 ++++++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 300b94b..2ab6353 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -268,9 +268,13 @@ private: if (node.output()[i].empty()) { outputTypes.emplace_back(builder_.getNoneType()); } else { - if (i < outputMap.size() && outputMap[i] >= MAX_TYPE) { + auto j = i; + // Variadic output is a single ODS result. + if (variadicOut) + j = 0; + if (j < outputMap.size() && outputMap[j] >= MAX_TYPE) { // Mapping gives a connection with an input. - mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType(); + mlir::Type inputType = inputs[outputMap[j] - MAX_TYPE].getType(); if (inputType.isa()) { auto elementType = inputType.cast().getElementType(); @@ -279,9 +283,9 @@ private: } else { outputTypes.push_back(inputType); } - } else if (i < outputMap.size() && outputMap[i] != -1) { + } else if (j < outputMap.size() && outputMap[j] != -1) { // Mapping gives a direct type. - auto elementType = buildTypeFromIndex(outputMap[i]); + auto elementType = buildTypeFromIndex(outputMap[j]); auto outType = mlir::UnrankedTensorType::get(elementType); outputTypes.emplace_back(outType); } else { @@ -305,13 +309,20 @@ private: op.getOperation())) { auto outTypes = opWithTypeInference.resultTypeInference(); for (int i = 0; i < node.output().size(); i++) { - (*op.getODSResults(i).begin()).setType(outTypes[i]); + if (variadicOut) + (*(op.getODSResults(0).begin() + i)).setType(outTypes[i]); + else + (*op.getODSResults(i).begin()).setType(outTypes[i]); } } for (int i = 0; i < node.output().size(); i++) { - frontend_symbols_.AddMapping( - legalize_name(node.output()[i]), *(op.getODSResults(i).begin())); + if (variadicOut) + frontend_symbols_.AddMapping(legalize_name(node.output()[i]), + *(op.getODSResults(0).begin() + i)); + else + frontend_symbols_.AddMapping( + legalize_name(node.output()[i]), *(op.getODSResults(i).begin())); } }