diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index 05c079a..3547f1d 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -745,8 +745,6 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) { mlir::MLIRContext context; FrontendGenImpl myONNXGen(context); auto module = myONNXGen.ImportONNXModel(model); - module.dump(); - return module; } @@ -761,6 +759,5 @@ void ImportFrontendModelFile(std::string model_fname, FrontendGenImpl myONNXGen(context); module = myONNXGen.ImportONNXModel(model); - module->dump(); } } // namespace onnf diff --git a/src/main.cpp b/src/main.cpp index 078f229..eb4b8e4 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -77,21 +77,36 @@ int main(int argc, char *argv[]) { llvm::cl::opt inputFilename( llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-"), llvm::cl::cat(OnnfOptions)); + + enum EmissionTargetType { + EmitONNXIR, + EmitMLIR, + EmitLLVMIR, + EmitLLVMBC, + }; + llvm::cl::opt emissionTarget( + llvm::cl::desc("Choose target to emit:"), + llvm::cl::values( + clEnumVal(EmitONNXIR, "No optimizations, enable debugging"), + clEnumVal(EmitMLIR, "Enable trivial optimizations"), + clEnumVal(EmitLLVMIR, "Enable default optimizations"), + clEnumVal(EmitLLVMBC, "Enable expensive optimizations")), + llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnfOptions)); + llvm::cl::HideUnrelatedOptions(OnnfOptions); llvm::cl::ParseCommandLineOptions(argc, argv, "ONNF MLIR modular optimizer driver\n"); - mlir::MLIRContext context; - mlir::OwningModuleRef module; - // Decide if the input file is an ONNX model or a model specified // in MLIR. The extension of the file is the decider. string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1); bool inputIsONNX = (extension == "onnx"); bool inputIsMLIR = (extension == "mlir"); - assert(inputIsONNX != inputIsMLIR && "Either ONNX model or MLIR file needs to be provided."); + + mlir::MLIRContext context; + mlir::OwningModuleRef module; if (inputIsONNX) { ImportFrontendModelFile(inputFilename, context, module); } else { @@ -101,17 +116,26 @@ int main(int argc, char *argv[]) { mlir::PassManager pm(&context); pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createLowerToKrnlPass()); - pm.addPass(mlir::createLowerKrnlPass()); - pm.addPass(mlir::createLowerAffinePass()); - pm.addPass(mlir::createLowerToCFGPass()); - pm.addPass(mlir::createKrnlLowerToLLVMPass()); - pm.addPass(mlir::createCanonicalizerPass()); + + if (emissionTarget >= EmitMLIR) { + pm.addPass(mlir::createLowerToKrnlPass()); + pm.addPass(mlir::createLowerKrnlPass()); + } + + if (emissionTarget >= EmitLLVMIR) { + pm.addPass(mlir::createLowerAffinePass()); + pm.addPass(mlir::createLowerToCFGPass()); + pm.addPass(mlir::createKrnlLowerToLLVMPass()); + pm.addPass(mlir::createCanonicalizerPass()); + } if (mlir::failed(pm.run(*module))) return 4; + module->dump(); + // Write LLVM bitcode to disk. - EmitLLVMBitCode(module); + if (emissionTarget == EmitLLVMBC) + EmitLLVMBitCode(module); return 0; } diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt index 8522310..040ae01 100644 --- a/src/runtime/CMakeLists.txt +++ b/src/runtime/CMakeLists.txt @@ -15,3 +15,4 @@ target_link_libraries(pyruntime PRIVATE ${CMAKE_DL_LIBS}) target_include_directories(pyruntime PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} ${ONNF_SRC_ROOT}) +add_dependencies(pyruntime cruntime)