Fix importing variadic output (#152)
* Fix importing variadic output * clang-format
This commit is contained in:
parent
4f8fd9d1bf
commit
2b6befce87
|
@ -268,9 +268,13 @@ private:
|
||||||
if (node.output()[i].empty()) {
|
if (node.output()[i].empty()) {
|
||||||
outputTypes.emplace_back(builder_.getNoneType());
|
outputTypes.emplace_back(builder_.getNoneType());
|
||||||
} else {
|
} else {
|
||||||
if (i < outputMap.size() && outputMap[i] >= MAX_TYPE) {
|
auto j = i;
|
||||||
|
// Variadic output is a single ODS result.
|
||||||
|
if (variadicOut)
|
||||||
|
j = 0;
|
||||||
|
if (j < outputMap.size() && outputMap[j] >= MAX_TYPE) {
|
||||||
// Mapping gives a connection with an input.
|
// Mapping gives a connection with an input.
|
||||||
mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType();
|
mlir::Type inputType = inputs[outputMap[j] - MAX_TYPE].getType();
|
||||||
if (inputType.isa<mlir::TensorType>()) {
|
if (inputType.isa<mlir::TensorType>()) {
|
||||||
auto elementType =
|
auto elementType =
|
||||||
inputType.cast<mlir::TensorType>().getElementType();
|
inputType.cast<mlir::TensorType>().getElementType();
|
||||||
|
@ -279,9 +283,9 @@ private:
|
||||||
} else {
|
} else {
|
||||||
outputTypes.push_back(inputType);
|
outputTypes.push_back(inputType);
|
||||||
}
|
}
|
||||||
} else if (i < outputMap.size() && outputMap[i] != -1) {
|
} else if (j < outputMap.size() && outputMap[j] != -1) {
|
||||||
// Mapping gives a direct type.
|
// Mapping gives a direct type.
|
||||||
auto elementType = buildTypeFromIndex(outputMap[i]);
|
auto elementType = buildTypeFromIndex(outputMap[j]);
|
||||||
auto outType = mlir::UnrankedTensorType::get(elementType);
|
auto outType = mlir::UnrankedTensorType::get(elementType);
|
||||||
outputTypes.emplace_back(outType);
|
outputTypes.emplace_back(outType);
|
||||||
} else {
|
} else {
|
||||||
|
@ -305,11 +309,18 @@ private:
|
||||||
op.getOperation())) {
|
op.getOperation())) {
|
||||||
auto outTypes = opWithTypeInference.resultTypeInference();
|
auto outTypes = opWithTypeInference.resultTypeInference();
|
||||||
for (int i = 0; i < node.output().size(); i++) {
|
for (int i = 0; i < node.output().size(); i++) {
|
||||||
|
if (variadicOut)
|
||||||
|
(*(op.getODSResults(0).begin() + i)).setType(outTypes[i]);
|
||||||
|
else
|
||||||
(*op.getODSResults(i).begin()).setType(outTypes[i]);
|
(*op.getODSResults(i).begin()).setType(outTypes[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < node.output().size(); i++) {
|
for (int i = 0; i < node.output().size(); i++) {
|
||||||
|
if (variadicOut)
|
||||||
|
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
|
||||||
|
*(op.getODSResults(0).begin() + i));
|
||||||
|
else
|
||||||
frontend_symbols_.AddMapping(
|
frontend_symbols_.AddMapping(
|
||||||
legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
|
legalize_name(node.output()[i]), *(op.getODSResults(i).begin()));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue