a commandline interface for onnf
This commit is contained in:
parent
911cc2ad92
commit
82d513096e
|
@ -745,8 +745,6 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
FrontendGenImpl myONNXGen(context);
|
FrontendGenImpl myONNXGen(context);
|
||||||
auto module = myONNXGen.ImportONNXModel(model);
|
auto module = myONNXGen.ImportONNXModel(model);
|
||||||
module.dump();
|
|
||||||
|
|
||||||
return module;
|
return module;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -761,6 +759,5 @@ void ImportFrontendModelFile(std::string model_fname,
|
||||||
|
|
||||||
FrontendGenImpl myONNXGen(context);
|
FrontendGenImpl myONNXGen(context);
|
||||||
module = myONNXGen.ImportONNXModel(model);
|
module = myONNXGen.ImportONNXModel(model);
|
||||||
module->dump();
|
|
||||||
}
|
}
|
||||||
} // namespace onnf
|
} // namespace onnf
|
||||||
|
|
46
src/main.cpp
46
src/main.cpp
|
@ -77,21 +77,36 @@ int main(int argc, char *argv[]) {
|
||||||
llvm::cl::opt<string> inputFilename(
|
llvm::cl::opt<string> inputFilename(
|
||||||
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
|
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
|
||||||
llvm::cl::cat(OnnfOptions));
|
llvm::cl::cat(OnnfOptions));
|
||||||
|
|
||||||
|
enum EmissionTargetType {
|
||||||
|
EmitONNXIR,
|
||||||
|
EmitMLIR,
|
||||||
|
EmitLLVMIR,
|
||||||
|
EmitLLVMBC,
|
||||||
|
};
|
||||||
|
llvm::cl::opt<EmissionTargetType> emissionTarget(
|
||||||
|
llvm::cl::desc("Choose target to emit:"),
|
||||||
|
llvm::cl::values(
|
||||||
|
clEnumVal(EmitONNXIR, "No optimizations, enable debugging"),
|
||||||
|
clEnumVal(EmitMLIR, "Enable trivial optimizations"),
|
||||||
|
clEnumVal(EmitLLVMIR, "Enable default optimizations"),
|
||||||
|
clEnumVal(EmitLLVMBC, "Enable expensive optimizations")),
|
||||||
|
llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnfOptions));
|
||||||
|
|
||||||
llvm::cl::HideUnrelatedOptions(OnnfOptions);
|
llvm::cl::HideUnrelatedOptions(OnnfOptions);
|
||||||
llvm::cl::ParseCommandLineOptions(argc, argv,
|
llvm::cl::ParseCommandLineOptions(argc, argv,
|
||||||
"ONNF MLIR modular optimizer driver\n");
|
"ONNF MLIR modular optimizer driver\n");
|
||||||
|
|
||||||
mlir::MLIRContext context;
|
|
||||||
mlir::OwningModuleRef module;
|
|
||||||
|
|
||||||
// Decide if the input file is an ONNX model or a model specified
|
// Decide if the input file is an ONNX model or a model specified
|
||||||
// in MLIR. The extension of the file is the decider.
|
// in MLIR. The extension of the file is the decider.
|
||||||
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
|
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
|
||||||
bool inputIsONNX = (extension == "onnx");
|
bool inputIsONNX = (extension == "onnx");
|
||||||
bool inputIsMLIR = (extension == "mlir");
|
bool inputIsMLIR = (extension == "mlir");
|
||||||
|
|
||||||
assert(inputIsONNX != inputIsMLIR &&
|
assert(inputIsONNX != inputIsMLIR &&
|
||||||
"Either ONNX model or MLIR file needs to be provided.");
|
"Either ONNX model or MLIR file needs to be provided.");
|
||||||
|
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
mlir::OwningModuleRef module;
|
||||||
if (inputIsONNX) {
|
if (inputIsONNX) {
|
||||||
ImportFrontendModelFile(inputFilename, context, module);
|
ImportFrontendModelFile(inputFilename, context, module);
|
||||||
} else {
|
} else {
|
||||||
|
@ -101,17 +116,26 @@ int main(int argc, char *argv[]) {
|
||||||
mlir::PassManager pm(&context);
|
mlir::PassManager pm(&context);
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
pm.addPass(mlir::createLowerToKrnlPass());
|
|
||||||
pm.addPass(mlir::createLowerKrnlPass());
|
if (emissionTarget >= EmitMLIR) {
|
||||||
pm.addPass(mlir::createLowerAffinePass());
|
pm.addPass(mlir::createLowerToKrnlPass());
|
||||||
pm.addPass(mlir::createLowerToCFGPass());
|
pm.addPass(mlir::createLowerKrnlPass());
|
||||||
pm.addPass(mlir::createKrnlLowerToLLVMPass());
|
}
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
|
||||||
|
if (emissionTarget >= EmitLLVMIR) {
|
||||||
|
pm.addPass(mlir::createLowerAffinePass());
|
||||||
|
pm.addPass(mlir::createLowerToCFGPass());
|
||||||
|
pm.addPass(mlir::createKrnlLowerToLLVMPass());
|
||||||
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
|
}
|
||||||
|
|
||||||
if (mlir::failed(pm.run(*module)))
|
if (mlir::failed(pm.run(*module)))
|
||||||
return 4;
|
return 4;
|
||||||
|
|
||||||
|
module->dump();
|
||||||
|
|
||||||
// Write LLVM bitcode to disk.
|
// Write LLVM bitcode to disk.
|
||||||
EmitLLVMBitCode(module);
|
if (emissionTarget == EmitLLVMBC)
|
||||||
|
EmitLLVMBitCode(module);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,3 +15,4 @@ target_link_libraries(pyruntime PRIVATE ${CMAKE_DL_LIBS})
|
||||||
target_include_directories(pyruntime
|
target_include_directories(pyruntime
|
||||||
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||||
${ONNF_SRC_ROOT})
|
${ONNF_SRC_ROOT})
|
||||||
|
add_dependencies(pyruntime cruntime)
|
||||||
|
|
Loading…
Reference in New Issue