Import optional outputs as NoneType (#57)
* Import optional outputs as NoneType * Allow NoneType results after the shape inference * Use empty() to check an empty string
This commit is contained in:
		
							parent
							
								
									55cbe316fd
								
							
						
					
					
						commit
						c8758545e7
					
				|  | @ -248,15 +248,31 @@ private: | |||
|     bool variadicIn = expectedNumOperands == -1; | ||||
|     bool variadicOut = expectedNumResults == -1; | ||||
| 
 | ||||
|     // In ONNX, there are two ways to leave an optional input or output
 | ||||
|     // unspecified: the first, available only for trailing inputs and outputs,
 | ||||
|     // is to simply not provide that input; the second method is to use an empty
 | ||||
|     // string in place of an input or output name.
 | ||||
|     //
 | ||||
|     // Here, we import optional inputs and outputs as NoneType.
 | ||||
| 
 | ||||
|     // Trailing optional inputs.
 | ||||
|     if (!variadicIn) | ||||
|       for (auto i = inputs.size(); i < expectedNumOperands; i++) | ||||
|         inputs.emplace_back(none_); | ||||
| 
 | ||||
|     std::vector<mlir::Type> outputTypes; | ||||
|     for (auto item : node.output()) { | ||||
|       // Optional outputs using empty string.
 | ||||
|       if (item.empty()) | ||||
|         outputTypes.emplace_back(builder_.getNoneType()); | ||||
|       else | ||||
|         outputTypes.push_back( | ||||
|             mlir::UnrankedTensorType::get(builder_.getF32Type())); | ||||
|     } | ||||
|     // Trailing optional outputs.
 | ||||
|     if (!variadicOut) | ||||
|       for (int i = node.output().size(); i < expectedNumResults; ++i) | ||||
|         outputTypes.emplace_back(builder_.getNoneType()); | ||||
| 
 | ||||
|     auto attributes = ImportNodeAttributes(node); | ||||
| 
 | ||||
|  |  | |||
|  | @ -126,7 +126,8 @@ public: | |||
|         op->getName().getStringRef() != "onnx.Unsqueeze") | ||||
|       return false; | ||||
|     return llvm::any_of(op->getResultTypes(), [](Type result_type) { | ||||
|       return !result_type.isa<RankedTensorType>(); | ||||
|       return !result_type.isa<NoneType>() && | ||||
|              !result_type.isa<RankedTensorType>(); | ||||
|     }); | ||||
|   } | ||||
| }; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue