Import optional outputs as NoneType (#57)

* Import optional outputs as NoneType

* Allow NoneType results after the shape inference

* Use empty() to check an empty string
This commit is contained in:
Tung D. Le 2020-03-31 10:21:18 +09:00 committed by GitHub
parent 55cbe316fd
commit c8758545e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 3 deletions

View File

@ -248,15 +248,31 @@ private:
bool variadicIn = expectedNumOperands == -1; bool variadicIn = expectedNumOperands == -1;
bool variadicOut = expectedNumResults == -1; bool variadicOut = expectedNumResults == -1;
// In ONNX, there are two ways to leave an optional input or output
// unspecified: the first, available only for trailing inputs and outputs,
// is to simply not provide that input; the second method is to use an empty
// string in place of an input or output name.
//
// Here, we import optional inputs and outputs as NoneType.
// Trailing optional inputs.
if (!variadicIn) if (!variadicIn)
for (auto i = inputs.size(); i < expectedNumOperands; i++) for (auto i = inputs.size(); i < expectedNumOperands; i++)
inputs.emplace_back(none_); inputs.emplace_back(none_);
std::vector<mlir::Type> outputTypes; std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) { for (auto item : node.output()) {
outputTypes.push_back( // Optional outputs using empty string.
mlir::UnrankedTensorType::get(builder_.getF32Type())); if (item.empty())
outputTypes.emplace_back(builder_.getNoneType());
else
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
} }
// Trailing optional outputs.
if (!variadicOut)
for (int i = node.output().size(); i < expectedNumResults; ++i)
outputTypes.emplace_back(builder_.getNoneType());
auto attributes = ImportNodeAttributes(node); auto attributes = ImportNodeAttributes(node);

View File

@ -126,7 +126,8 @@ public:
op->getName().getStringRef() != "onnx.Unsqueeze") op->getName().getStringRef() != "onnx.Unsqueeze")
return false; return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) { return llvm::any_of(op->getResultTypes(), [](Type result_type) {
return !result_type.isa<RankedTensorType>(); return !result_type.isa<NoneType>() &&
!result_type.isa<RankedTensorType>();
}); });
} }
}; };