Import optional outputs as NoneType (#57)
* Import optional outputs as NoneType * Allow NoneType results after the shape inference * Use empty() to check an empty string
This commit is contained in:
parent
55cbe316fd
commit
c8758545e7
|
@ -248,15 +248,31 @@ private:
|
||||||
bool variadicIn = expectedNumOperands == -1;
|
bool variadicIn = expectedNumOperands == -1;
|
||||||
bool variadicOut = expectedNumResults == -1;
|
bool variadicOut = expectedNumResults == -1;
|
||||||
|
|
||||||
|
// In ONNX, there are two ways to leave an optional input or output
|
||||||
|
// unspecified: the first, available only for trailing inputs and outputs,
|
||||||
|
// is to simply not provide that input; the second method is to use an empty
|
||||||
|
// string in place of an input or output name.
|
||||||
|
//
|
||||||
|
// Here, we import optional inputs and outputs as NoneType.
|
||||||
|
|
||||||
|
// Trailing optional inputs.
|
||||||
if (!variadicIn)
|
if (!variadicIn)
|
||||||
for (auto i = inputs.size(); i < expectedNumOperands; i++)
|
for (auto i = inputs.size(); i < expectedNumOperands; i++)
|
||||||
inputs.emplace_back(none_);
|
inputs.emplace_back(none_);
|
||||||
|
|
||||||
std::vector<mlir::Type> outputTypes;
|
std::vector<mlir::Type> outputTypes;
|
||||||
for (auto item : node.output()) {
|
for (auto item : node.output()) {
|
||||||
|
// Optional outputs using empty string.
|
||||||
|
if (item.empty())
|
||||||
|
outputTypes.emplace_back(builder_.getNoneType());
|
||||||
|
else
|
||||||
outputTypes.push_back(
|
outputTypes.push_back(
|
||||||
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
}
|
}
|
||||||
|
// Trailing optional outputs.
|
||||||
|
if (!variadicOut)
|
||||||
|
for (int i = node.output().size(); i < expectedNumResults; ++i)
|
||||||
|
outputTypes.emplace_back(builder_.getNoneType());
|
||||||
|
|
||||||
auto attributes = ImportNodeAttributes(node);
|
auto attributes = ImportNodeAttributes(node);
|
||||||
|
|
||||||
|
|
|
@ -126,7 +126,8 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
op->getName().getStringRef() != "onnx.Unsqueeze")
|
||||||
return false;
|
return false;
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||||
return !result_type.isa<RankedTensorType>();
|
return !result_type.isa<NoneType>() &&
|
||||||
|
!result_type.isa<RankedTensorType>();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue