diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 300b94b..2ab6353 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -268,9 +268,13 @@ private: if (node.output()[i].empty()) { outputTypes.emplace_back(builder_.getNoneType()); } else { - if (i < outputMap.size() && outputMap[i] >= MAX_TYPE) { + auto j = i; + // Variadic output is a single ODS result. + if (variadicOut) + j = 0; + if (j < outputMap.size() && outputMap[j] >= MAX_TYPE) { // Mapping gives a connection with an input. - mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType(); + mlir::Type inputType = inputs[outputMap[j] - MAX_TYPE].getType(); if (inputType.isa()) { auto elementType = inputType.cast().getElementType(); @@ -279,9 +283,9 @@ private: } else { outputTypes.push_back(inputType); } - } else if (i < outputMap.size() && outputMap[i] != -1) { + } else if (j < outputMap.size() && outputMap[j] != -1) { // Mapping gives a direct type. - auto elementType = buildTypeFromIndex(outputMap[i]); + auto elementType = buildTypeFromIndex(outputMap[j]); auto outType = mlir::UnrankedTensorType::get(elementType); outputTypes.emplace_back(outType); } else { @@ -305,13 +309,20 @@ private: op.getOperation())) { auto outTypes = opWithTypeInference.resultTypeInference(); for (int i = 0; i < node.output().size(); i++) { - (*op.getODSResults(i).begin()).setType(outTypes[i]); + if (variadicOut) + (*(op.getODSResults(0).begin() + i)).setType(outTypes[i]); + else + (*op.getODSResults(i).begin()).setType(outTypes[i]); } } for (int i = 0; i < node.output().size(); i++) { - frontend_symbols_.AddMapping( - legalize_name(node.output()[i]), *(op.getODSResults(i).begin())); + if (variadicOut) + frontend_symbols_.AddMapping(legalize_name(node.output()[i]), + *(op.getODSResults(0).begin() + i)); + else + frontend_symbols_.AddMapping( + legalize_name(node.output()[i]), *(op.getODSResults(i).begin())); } }