diff --git a/src/main.cpp b/src/main.cpp index 83d7a56..078f229 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -32,6 +32,8 @@ #include "mlir/Target/LLVMIR.h" #include "mlir/Transforms/Passes.h" +void EmitLLVMBitCode(const mlir::OwningModuleRef &module); + using namespace std; using namespace onnf; @@ -57,13 +59,22 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context, } } +void EmitLLVMBitCode(const mlir::OwningModuleRef &module) { + error_code error; + llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error, + llvm::sys::fs::F_None); + llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module), + moduleBitcodeStream); + moduleBitcodeStream.flush(); +} + int main(int argc, char *argv[]) { mlir::registerDialect(); mlir::registerDialect(); llvm::cl::OptionCategory OnnfOptions("ONNF Options", "These are frontend options."); - llvm::cl::opt InputFilename( + llvm::cl::opt inputFilename( llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-"), llvm::cl::cat(OnnfOptions)); llvm::cl::HideUnrelatedOptions(OnnfOptions); @@ -75,17 +86,16 @@ int main(int argc, char *argv[]) { // Decide if the input file is an ONNX model or a model specified // in MLIR. The extension of the file is the decider. - string extension = - InputFilename.substr(InputFilename.find_last_of(".") + 1); - bool onnx_model_provided = (extension == "onnx"); - bool mlir_model_provided = (extension == "mlir"); + string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1); + bool inputIsONNX = (extension == "onnx"); + bool inputIsMLIR = (extension == "mlir"); - if (onnx_model_provided) { - ImportFrontendModelFile(InputFilename, context, module); - } else if (mlir_model_provided) { - LoadMLIR(InputFilename, context, module); + assert(inputIsONNX != inputIsMLIR && + "Either ONNX model or MLIR file needs to be provided."); + if (inputIsONNX) { + ImportFrontendModelFile(inputFilename, context, module); } else { - assert(false && "No ONNX or MLIR models provided!"); + LoadMLIR(inputFilename, context, module); } mlir::PassManager pm(&context); @@ -97,15 +107,11 @@ int main(int argc, char *argv[]) { pm.addPass(mlir::createLowerToCFGPass()); pm.addPass(mlir::createKrnlLowerToLLVMPass()); pm.addPass(mlir::createCanonicalizerPass()); - pm.run(*module); + + if (mlir::failed(pm.run(*module))) + return 4; // Write LLVM bitcode to disk. - std::error_code EC; - llvm::raw_fd_ostream moduleBitcodeStream("model.bc", EC, - llvm::sys::fs::F_None); - llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module), - moduleBitcodeStream); - moduleBitcodeStream.flush(); - + EmitLLVMBitCode(module); return 0; }