Fix importing variadic output (#152)
* Fix importing variadic output * clang-format
This commit is contained in:
parent
4f8fd9d1bf
commit
2b6befce87
|
@ -268,9 +268,13 @@ private:
|
|||
if (node.output()[i].empty()) {
|
||||
outputTypes.emplace_back(builder_.getNoneType());
|
||||
} else {
|
||||
if (i < outputMap.size() && outputMap[i] >= MAX_TYPE) {
|
||||
auto j = i;
|
||||
// Variadic output is a single ODS result.
|
||||
if (variadicOut)
|
||||
j = 0;
|
||||
if (j < outputMap.size() && outputMap[j] >= MAX_TYPE) {
|
||||
// Mapping gives a connection with an input.
|
||||
mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType();
|
||||
mlir::Type inputType = inputs[outputMap[j] - MAX_TYPE].getType();
|
||||
if (inputType.isa<mlir::TensorType>()) {
|
||||
auto elementType =
|
||||
inputType.cast<mlir::TensorType>().getElementType();
|
||||
|
@ -279,9 +283,9 @@ private:
|
|||
} else {
|
||||
outputTypes.push_back(inputType);
|
||||
}
|
||||
} else if (i < outputMap.size() && outputMap[i] != -1) {
|
||||
} else if (j < outputMap.size() && outputMap[j] != -1) {
|
||||
// Mapping gives a direct type.
|
||||
auto elementType = buildTypeFromIndex(outputMap[i]);
|
||||
auto elementType = buildTypeFromIndex(outputMap[j]);
|
||||
auto outType = mlir::UnrankedTensorType::get(elementType);
|
||||
outputTypes.emplace_back(outType);
|
||||
} else {
|
||||
|
@ -305,13 +309,20 @@ private:
|
|||
op.getOperation())) {
|
||||
auto outTypes = opWithTypeInference.resultTypeInference();
|
||||
for (int i = 0; i < node.output().size(); i++) {
|
||||
(*op.getODSResults(i).begin()).setType(outTypes[i]);
|
||||
if (variadicOut)
|
||||
(*(op.getODSResults(0).begin() + i)).setType(outTypes[i]);
|
||||
else
|
||||
(*op.getODSResults(i).begin()).setType(outTypes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < node.output().size(); i++) {
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
|
||||
if (variadicOut)
|
||||
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
|
||||
*(op.getODSResults(0).begin() + i));
|
||||
else
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue