a commandline interface for onnf
This commit is contained in:
		
							parent
							
								
									911cc2ad92
								
							
						
					
					
						commit
						82d513096e
					
				| 
						 | 
				
			
			@ -745,8 +745,6 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
 | 
			
		|||
  mlir::MLIRContext context;
 | 
			
		||||
  FrontendGenImpl myONNXGen(context);
 | 
			
		||||
  auto module = myONNXGen.ImportONNXModel(model);
 | 
			
		||||
  module.dump();
 | 
			
		||||
 | 
			
		||||
  return module;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -761,6 +759,5 @@ void ImportFrontendModelFile(std::string model_fname,
 | 
			
		|||
 | 
			
		||||
  FrontendGenImpl myONNXGen(context);
 | 
			
		||||
  module = myONNXGen.ImportONNXModel(model);
 | 
			
		||||
  module->dump();
 | 
			
		||||
}
 | 
			
		||||
}  // namespace onnf
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										46
									
								
								src/main.cpp
								
								
								
								
							
							
						
						
									
										46
									
								
								src/main.cpp
								
								
								
								
							| 
						 | 
				
			
			@ -77,21 +77,36 @@ int main(int argc, char *argv[]) {
 | 
			
		|||
  llvm::cl::opt<string> inputFilename(
 | 
			
		||||
      llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
 | 
			
		||||
      llvm::cl::cat(OnnfOptions));
 | 
			
		||||
 | 
			
		||||
  enum EmissionTargetType {
 | 
			
		||||
    EmitONNXIR,
 | 
			
		||||
    EmitMLIR,
 | 
			
		||||
    EmitLLVMIR,
 | 
			
		||||
    EmitLLVMBC,
 | 
			
		||||
  };
 | 
			
		||||
  llvm::cl::opt<EmissionTargetType> emissionTarget(
 | 
			
		||||
      llvm::cl::desc("Choose target to emit:"),
 | 
			
		||||
      llvm::cl::values(
 | 
			
		||||
          clEnumVal(EmitONNXIR, "No optimizations, enable debugging"),
 | 
			
		||||
          clEnumVal(EmitMLIR, "Enable trivial optimizations"),
 | 
			
		||||
          clEnumVal(EmitLLVMIR, "Enable default optimizations"),
 | 
			
		||||
          clEnumVal(EmitLLVMBC, "Enable expensive optimizations")),
 | 
			
		||||
      llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnfOptions));
 | 
			
		||||
 | 
			
		||||
  llvm::cl::HideUnrelatedOptions(OnnfOptions);
 | 
			
		||||
  llvm::cl::ParseCommandLineOptions(argc, argv,
 | 
			
		||||
                                    "ONNF MLIR modular optimizer driver\n");
 | 
			
		||||
 | 
			
		||||
  mlir::MLIRContext context;
 | 
			
		||||
  mlir::OwningModuleRef module;
 | 
			
		||||
 | 
			
		||||
  // Decide if the input file is an ONNX model or a model specified
 | 
			
		||||
  // in MLIR. The extension of the file is the decider.
 | 
			
		||||
  string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
 | 
			
		||||
  bool inputIsONNX = (extension == "onnx");
 | 
			
		||||
  bool inputIsMLIR = (extension == "mlir");
 | 
			
		||||
 | 
			
		||||
  assert(inputIsONNX != inputIsMLIR &&
 | 
			
		||||
         "Either ONNX model or MLIR file needs to be provided.");
 | 
			
		||||
 | 
			
		||||
  mlir::MLIRContext context;
 | 
			
		||||
  mlir::OwningModuleRef module;
 | 
			
		||||
  if (inputIsONNX) {
 | 
			
		||||
    ImportFrontendModelFile(inputFilename, context, module);
 | 
			
		||||
  } else {
 | 
			
		||||
| 
						 | 
				
			
			@ -101,17 +116,26 @@ int main(int argc, char *argv[]) {
 | 
			
		|||
  mlir::PassManager pm(&context);
 | 
			
		||||
  pm.addPass(mlir::createShapeInferencePass());
 | 
			
		||||
  pm.addPass(mlir::createCanonicalizerPass());
 | 
			
		||||
  pm.addPass(mlir::createLowerToKrnlPass());
 | 
			
		||||
  pm.addPass(mlir::createLowerKrnlPass());
 | 
			
		||||
  pm.addPass(mlir::createLowerAffinePass());
 | 
			
		||||
  pm.addPass(mlir::createLowerToCFGPass());
 | 
			
		||||
  pm.addPass(mlir::createKrnlLowerToLLVMPass());
 | 
			
		||||
  pm.addPass(mlir::createCanonicalizerPass());
 | 
			
		||||
 | 
			
		||||
  if (emissionTarget >= EmitMLIR) {
 | 
			
		||||
    pm.addPass(mlir::createLowerToKrnlPass());
 | 
			
		||||
    pm.addPass(mlir::createLowerKrnlPass());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (emissionTarget >= EmitLLVMIR) {
 | 
			
		||||
    pm.addPass(mlir::createLowerAffinePass());
 | 
			
		||||
    pm.addPass(mlir::createLowerToCFGPass());
 | 
			
		||||
    pm.addPass(mlir::createKrnlLowerToLLVMPass());
 | 
			
		||||
    pm.addPass(mlir::createCanonicalizerPass());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (mlir::failed(pm.run(*module)))
 | 
			
		||||
    return 4;
 | 
			
		||||
 | 
			
		||||
  module->dump();
 | 
			
		||||
 | 
			
		||||
  // Write LLVM bitcode to disk.
 | 
			
		||||
  EmitLLVMBitCode(module);
 | 
			
		||||
  if (emissionTarget == EmitLLVMBC)
 | 
			
		||||
    EmitLLVMBitCode(module);
 | 
			
		||||
  return 0;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,3 +15,4 @@ target_link_libraries(pyruntime PRIVATE ${CMAKE_DL_LIBS})
 | 
			
		|||
target_include_directories(pyruntime
 | 
			
		||||
        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
 | 
			
		||||
        ${ONNF_SRC_ROOT})
 | 
			
		||||
add_dependencies(pyruntime cruntime)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue