a commandline interface for onnf
This commit is contained in:
		
							parent
							
								
									911cc2ad92
								
							
						
					
					
						commit
						82d513096e
					
				| 
						 | 
					@ -745,8 +745,6 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
 | 
				
			||||||
  mlir::MLIRContext context;
 | 
					  mlir::MLIRContext context;
 | 
				
			||||||
  FrontendGenImpl myONNXGen(context);
 | 
					  FrontendGenImpl myONNXGen(context);
 | 
				
			||||||
  auto module = myONNXGen.ImportONNXModel(model);
 | 
					  auto module = myONNXGen.ImportONNXModel(model);
 | 
				
			||||||
  module.dump();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  return module;
 | 
					  return module;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -761,6 +759,5 @@ void ImportFrontendModelFile(std::string model_fname,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  FrontendGenImpl myONNXGen(context);
 | 
					  FrontendGenImpl myONNXGen(context);
 | 
				
			||||||
  module = myONNXGen.ImportONNXModel(model);
 | 
					  module = myONNXGen.ImportONNXModel(model);
 | 
				
			||||||
  module->dump();
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
}  // namespace onnf
 | 
					}  // namespace onnf
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										46
									
								
								src/main.cpp
								
								
								
								
							
							
						
						
									
										46
									
								
								src/main.cpp
								
								
								
								
							| 
						 | 
					@ -77,21 +77,36 @@ int main(int argc, char *argv[]) {
 | 
				
			||||||
  llvm::cl::opt<string> inputFilename(
 | 
					  llvm::cl::opt<string> inputFilename(
 | 
				
			||||||
      llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
 | 
					      llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
 | 
				
			||||||
      llvm::cl::cat(OnnfOptions));
 | 
					      llvm::cl::cat(OnnfOptions));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  enum EmissionTargetType {
 | 
				
			||||||
 | 
					    EmitONNXIR,
 | 
				
			||||||
 | 
					    EmitMLIR,
 | 
				
			||||||
 | 
					    EmitLLVMIR,
 | 
				
			||||||
 | 
					    EmitLLVMBC,
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					  llvm::cl::opt<EmissionTargetType> emissionTarget(
 | 
				
			||||||
 | 
					      llvm::cl::desc("Choose target to emit:"),
 | 
				
			||||||
 | 
					      llvm::cl::values(
 | 
				
			||||||
 | 
					          clEnumVal(EmitONNXIR, "No optimizations, enable debugging"),
 | 
				
			||||||
 | 
					          clEnumVal(EmitMLIR, "Enable trivial optimizations"),
 | 
				
			||||||
 | 
					          clEnumVal(EmitLLVMIR, "Enable default optimizations"),
 | 
				
			||||||
 | 
					          clEnumVal(EmitLLVMBC, "Enable expensive optimizations")),
 | 
				
			||||||
 | 
					      llvm::cl::init(EmitLLVMBC), llvm::cl::cat(OnnfOptions));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  llvm::cl::HideUnrelatedOptions(OnnfOptions);
 | 
					  llvm::cl::HideUnrelatedOptions(OnnfOptions);
 | 
				
			||||||
  llvm::cl::ParseCommandLineOptions(argc, argv,
 | 
					  llvm::cl::ParseCommandLineOptions(argc, argv,
 | 
				
			||||||
                                    "ONNF MLIR modular optimizer driver\n");
 | 
					                                    "ONNF MLIR modular optimizer driver\n");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  mlir::MLIRContext context;
 | 
					 | 
				
			||||||
  mlir::OwningModuleRef module;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Decide if the input file is an ONNX model or a model specified
 | 
					  // Decide if the input file is an ONNX model or a model specified
 | 
				
			||||||
  // in MLIR. The extension of the file is the decider.
 | 
					  // in MLIR. The extension of the file is the decider.
 | 
				
			||||||
  string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
 | 
					  string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
 | 
				
			||||||
  bool inputIsONNX = (extension == "onnx");
 | 
					  bool inputIsONNX = (extension == "onnx");
 | 
				
			||||||
  bool inputIsMLIR = (extension == "mlir");
 | 
					  bool inputIsMLIR = (extension == "mlir");
 | 
				
			||||||
 | 
					 | 
				
			||||||
  assert(inputIsONNX != inputIsMLIR &&
 | 
					  assert(inputIsONNX != inputIsMLIR &&
 | 
				
			||||||
         "Either ONNX model or MLIR file needs to be provided.");
 | 
					         "Either ONNX model or MLIR file needs to be provided.");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  mlir::MLIRContext context;
 | 
				
			||||||
 | 
					  mlir::OwningModuleRef module;
 | 
				
			||||||
  if (inputIsONNX) {
 | 
					  if (inputIsONNX) {
 | 
				
			||||||
    ImportFrontendModelFile(inputFilename, context, module);
 | 
					    ImportFrontendModelFile(inputFilename, context, module);
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
| 
						 | 
					@ -101,17 +116,26 @@ int main(int argc, char *argv[]) {
 | 
				
			||||||
  mlir::PassManager pm(&context);
 | 
					  mlir::PassManager pm(&context);
 | 
				
			||||||
  pm.addPass(mlir::createShapeInferencePass());
 | 
					  pm.addPass(mlir::createShapeInferencePass());
 | 
				
			||||||
  pm.addPass(mlir::createCanonicalizerPass());
 | 
					  pm.addPass(mlir::createCanonicalizerPass());
 | 
				
			||||||
  pm.addPass(mlir::createLowerToKrnlPass());
 | 
					
 | 
				
			||||||
  pm.addPass(mlir::createLowerKrnlPass());
 | 
					  if (emissionTarget >= EmitMLIR) {
 | 
				
			||||||
  pm.addPass(mlir::createLowerAffinePass());
 | 
					    pm.addPass(mlir::createLowerToKrnlPass());
 | 
				
			||||||
  pm.addPass(mlir::createLowerToCFGPass());
 | 
					    pm.addPass(mlir::createLowerKrnlPass());
 | 
				
			||||||
  pm.addPass(mlir::createKrnlLowerToLLVMPass());
 | 
					  }
 | 
				
			||||||
  pm.addPass(mlir::createCanonicalizerPass());
 | 
					
 | 
				
			||||||
 | 
					  if (emissionTarget >= EmitLLVMIR) {
 | 
				
			||||||
 | 
					    pm.addPass(mlir::createLowerAffinePass());
 | 
				
			||||||
 | 
					    pm.addPass(mlir::createLowerToCFGPass());
 | 
				
			||||||
 | 
					    pm.addPass(mlir::createKrnlLowerToLLVMPass());
 | 
				
			||||||
 | 
					    pm.addPass(mlir::createCanonicalizerPass());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (mlir::failed(pm.run(*module)))
 | 
					  if (mlir::failed(pm.run(*module)))
 | 
				
			||||||
    return 4;
 | 
					    return 4;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  module->dump();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Write LLVM bitcode to disk.
 | 
					  // Write LLVM bitcode to disk.
 | 
				
			||||||
  EmitLLVMBitCode(module);
 | 
					  if (emissionTarget == EmitLLVMBC)
 | 
				
			||||||
 | 
					    EmitLLVMBitCode(module);
 | 
				
			||||||
  return 0;
 | 
					  return 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,3 +15,4 @@ target_link_libraries(pyruntime PRIVATE ${CMAKE_DL_LIBS})
 | 
				
			||||||
target_include_directories(pyruntime
 | 
					target_include_directories(pyruntime
 | 
				
			||||||
        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
 | 
					        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
 | 
				
			||||||
        ${ONNF_SRC_ROOT})
 | 
					        ${ONNF_SRC_ROOT})
 | 
				
			||||||
 | 
					add_dependencies(pyruntime cruntime)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue