diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 91d25a6..02990c4 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -248,15 +248,31 @@ private: bool variadicIn = expectedNumOperands == -1; bool variadicOut = expectedNumResults == -1; + // In ONNX, there are two ways to leave an optional input or output + // unspecified: the first, available only for trailing inputs and outputs, + // is to simply not provide that input; the second method is to use an empty + // string in place of an input or output name. + // + // Here, we import optional inputs and outputs as NoneType. + + // Trailing optional inputs. if (!variadicIn) for (auto i = inputs.size(); i < expectedNumOperands; i++) inputs.emplace_back(none_); std::vector outputTypes; for (auto item : node.output()) { - outputTypes.push_back( - mlir::UnrankedTensorType::get(builder_.getF32Type())); + // Optional outputs using empty string. + if (item.empty()) + outputTypes.emplace_back(builder_.getNoneType()); + else + outputTypes.push_back( + mlir::UnrankedTensorType::get(builder_.getF32Type())); } + // Trailing optional outputs. + if (!variadicOut) + for (int i = node.output().size(); i < expectedNumResults; ++i) + outputTypes.emplace_back(builder_.getNoneType()); auto attributes = ImportNodeAttributes(node); diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index ab05931..1bf9855 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -126,7 +126,8 @@ public: op->getName().getStringRef() != "onnx.Unsqueeze") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { - return !result_type.isa(); + return !result_type.isa() && + !result_type.isa(); }); } };