a commandline interface for onnf

This commit is contained in:
Tian Jin 2019-12-22 23:52:49 -05:00
parent 911cc2ad92
commit 82d513096e
3 changed files with 36 additions and 14 deletions

View File

@ -745,8 +745,6 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
mlir::MLIRContext context; mlir::MLIRContext context;
FrontendGenImpl myONNXGen(context); FrontendGenImpl myONNXGen(context);
auto module = myONNXGen.ImportONNXModel(model); auto module = myONNXGen.ImportONNXModel(model);
module.dump();
return module; return module;
} }
@ -761,6 +759,5 @@ void ImportFrontendModelFile(std::string model_fname,
FrontendGenImpl myONNXGen(context); FrontendGenImpl myONNXGen(context);
module = myONNXGen.ImportONNXModel(model); module = myONNXGen.ImportONNXModel(model);
module->dump();
} }
} // namespace onnf } // namespace onnf

View File

@ -77,21 +77,36 @@ int main(int argc, char *argv[]) {
llvm::cl::opt<string> inputFilename( llvm::cl::opt<string> inputFilename(
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"), llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::cat(OnnfOptions)); llvm::cl::cat(OnnfOptions));
enum EmissionTargetType {
EmitONNXIR,
EmitMLIR,
EmitLLVMIR,
EmitLLVMBC,
};
llvm::cl::opt<EmissionTargetType> emissionTarget(
llvm::cl::desc("Choose target to emit:"),
llvm::cl::values(
clEnumVal(EmitONNXIR, "No optimizations, enable debugging"),
clEnumVal(EmitMLIR, "Enable trivial optimizations"),
clEnumVal(EmitLLVMIR, "Enable default optimizations"),
clEnumVal(EmitLLVMBC, "Enable expensive optimizations")),
llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnfOptions));
llvm::cl::HideUnrelatedOptions(OnnfOptions); llvm::cl::HideUnrelatedOptions(OnnfOptions);
llvm::cl::ParseCommandLineOptions(argc, argv, llvm::cl::ParseCommandLineOptions(argc, argv,
"ONNF MLIR modular optimizer driver\n"); "ONNF MLIR modular optimizer driver\n");
mlir::MLIRContext context;
mlir::OwningModuleRef module;
// Decide if the input file is an ONNX model or a model specified // Decide if the input file is an ONNX model or a model specified
// in MLIR. The extension of the file is the decider. // in MLIR. The extension of the file is the decider.
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1); string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
bool inputIsONNX = (extension == "onnx"); bool inputIsONNX = (extension == "onnx");
bool inputIsMLIR = (extension == "mlir"); bool inputIsMLIR = (extension == "mlir");
assert(inputIsONNX != inputIsMLIR && assert(inputIsONNX != inputIsMLIR &&
"Either ONNX model or MLIR file needs to be provided."); "Either ONNX model or MLIR file needs to be provided.");
mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (inputIsONNX) { if (inputIsONNX) {
ImportFrontendModelFile(inputFilename, context, module); ImportFrontendModelFile(inputFilename, context, module);
} else { } else {
@ -101,17 +116,26 @@ int main(int argc, char *argv[]) {
mlir::PassManager pm(&context); mlir::PassManager pm(&context);
pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createLowerToKrnlPass());
pm.addPass(mlir::createLowerKrnlPass()); if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerAffinePass()); pm.addPass(mlir::createLowerToKrnlPass());
pm.addPass(mlir::createLowerToCFGPass()); pm.addPass(mlir::createLowerKrnlPass());
pm.addPass(mlir::createKrnlLowerToLLVMPass()); }
pm.addPass(mlir::createCanonicalizerPass());
if (emissionTarget >= EmitLLVMIR) {
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createLowerToCFGPass());
pm.addPass(mlir::createKrnlLowerToLLVMPass());
pm.addPass(mlir::createCanonicalizerPass());
}
if (mlir::failed(pm.run(*module))) if (mlir::failed(pm.run(*module)))
return 4; return 4;
module->dump();
// Write LLVM bitcode to disk. // Write LLVM bitcode to disk.
EmitLLVMBitCode(module); if (emissionTarget == EmitLLVMBC)
EmitLLVMBitCode(module);
return 0; return 0;
} }

View File

@ -15,3 +15,4 @@ target_link_libraries(pyruntime PRIVATE ${CMAKE_DL_LIBS})
target_include_directories(pyruntime target_include_directories(pyruntime
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT}) ${ONNF_SRC_ROOT})
add_dependencies(pyruntime cruntime)