Import optional outputs as NoneType (#57)
* Import optional outputs as NoneType * Allow NoneType results after the shape inference * Use empty() to check an empty string
This commit is contained in:
parent
55cbe316fd
commit
c8758545e7
|
@ -248,15 +248,31 @@ private:
|
|||
bool variadicIn = expectedNumOperands == -1;
|
||||
bool variadicOut = expectedNumResults == -1;
|
||||
|
||||
// In ONNX, there are two ways to leave an optional input or output
|
||||
// unspecified: the first, available only for trailing inputs and outputs,
|
||||
// is to simply not provide that input; the second method is to use an empty
|
||||
// string in place of an input or output name.
|
||||
//
|
||||
// Here, we import optional inputs and outputs as NoneType.
|
||||
|
||||
// Trailing optional inputs.
|
||||
if (!variadicIn)
|
||||
for (auto i = inputs.size(); i < expectedNumOperands; i++)
|
||||
inputs.emplace_back(none_);
|
||||
|
||||
std::vector<mlir::Type> outputTypes;
|
||||
for (auto item : node.output()) {
|
||||
// Optional outputs using empty string.
|
||||
if (item.empty())
|
||||
outputTypes.emplace_back(builder_.getNoneType());
|
||||
else
|
||||
outputTypes.push_back(
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||
}
|
||||
// Trailing optional outputs.
|
||||
if (!variadicOut)
|
||||
for (int i = node.output().size(); i < expectedNumResults; ++i)
|
||||
outputTypes.emplace_back(builder_.getNoneType());
|
||||
|
||||
auto attributes = ImportNodeAttributes(node);
|
||||
|
||||
|
|
|
@ -126,7 +126,8 @@ public:
|
|||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
||||
return false;
|
||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||
return !result_type.isa<RankedTensorType>();
|
||||
return !result_type.isa<NoneType>() &&
|
||||
!result_type.isa<RankedTensorType>();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue