//===--------------------------- main_utils.cpp ---------------------------===// // // Copyright 2019-2020 The IBM Research Authors. // // ============================================================================= // // Functions for adding passes and processing input files. // //===----------------------------------------------------------------------===// #include "src/MainUtils.hpp" #include #include #include using namespace std; using namespace onnx_mlir; void LoadMLIR(string inputFilename, mlir::MLIRContext &context, mlir::OwningModuleRef &module) { // Handle '.mlir' input to the ONNX MLIR frontend. // The mlir format indicates that one or more of the supported // representations are used in the file. llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); if (std::error_code EC = fileOrErr.getError()) { llvm::errs() << "Could not open input file: " << EC.message() << "\n"; return; } // Parse the input mlir. llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); module = mlir::parseSourceFile(sourceMgr, &context); if (!module) { llvm::errs() << "Error can't load file " << inputFilename << "\n"; return; } } void EmitLLVMBitCode( const mlir::OwningModuleRef &module, string outputFilename) { error_code error; llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error, llvm::sys::fs::F_None); llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); moduleBitcodeStream.flush(); } void registerDialects() { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); } void addONNXToMLIRPasses(mlir::PassManager &pm) { pm.addPass(mlir::createDecomposeONNXToONNXPass()); pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createAttributePromotionPass()); } void addONNXToKrnlPasses(mlir::PassManager &pm) { pm.addPass(mlir::createLowerToKrnlPass()); // An additional pass of canonicalization is helpful because lowering // from ONNX dialect to Standard dialect exposes additional canonicalization // oppertunities. pm.addPass(mlir::createCanonicalizerPass()); } void addKrnlToAffinePasses(mlir::PassManager &pm) { pm.addPass(mlir::createLowerKrnlPass()); } void addKrnlToLLVMPasses(mlir::PassManager &pm) { pm.addPass(mlir::createLowerAffinePass()); pm.addPass(mlir::createLowerToCFGPass()); pm.addPass(mlir::createKrnlLowerToLLVMPass()); pm.addPass(mlir::createCanonicalizerPass()); } void processInputFile(string inputFilename, EmissionTargetType emissionTarget, mlir::MLIRContext &context, mlir::OwningModuleRef &module) { // Decide if the input file is an ONNX model or a model specified // in MLIR. The extension of the file is the decider. string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1); bool inputIsONNX = (extension == "onnx"); bool inputIsMLIR = (extension == "mlir"); assert(inputIsONNX != inputIsMLIR && "Either ONNX model or MLIR file needs to be provided."); if (inputIsONNX) { ImportFrontendModelFile(inputFilename, context, module); } else { LoadMLIR(inputFilename, context, module); } } void outputCode( mlir::OwningModuleRef &module, string filename, string extension) { // Start a separate process to redirect the model output. I/O redirection // changes will not be visible to the parent process. if (fork() == 0) { const char * tempFilename = (filename + extension).c_str(); freopen(tempFilename, "w", stderr); module->dump(); fclose(stderr); exit(0); } } void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget, mlir::MLIRContext &context, mlir::OwningModuleRef &module) { // For EmitONNXIR and EmitMLIR the constant value are embedded in the code // thus making the code hard to read. These values can be elided by emitting // two versions of the same source code: // (1) a version with all the constant values included meant for being passed // back to onnx-mlir for further processing and stored in: // // .onnx.mlir // // (2) a version without constants meant for being inspected by users and // stored in: // // .mlir // // In the case of the LLVM Dialect IR the constant values are grouped // outside the function code at the beginning of the file in which case the // elision of these constants is not strictly required. Elision is also not // necessary when emitting the .bc file. if (emissionTarget == EmitLLVMBC) { // Write LLVM bitcode to disk. string outputFilename = outputBaseName + ".bc"; EmitLLVMBitCode(module, outputFilename); printf("LLVM bitcode written to %s\n", outputFilename.c_str()); } else { // Emit the version with all constants included. outputCode(module, outputBaseName, ".onnx.mlir"); printf("Full MLIR code written to: \n\t%s\n\n", (outputBaseName + ".onnx.mlir").c_str()); // Apply specific passes to clean up the code where necessary. mlir::PassManager cleanSourcePM(&context); if (emissionTarget == EmitONNXIR || emissionTarget == EmitONNXBasic) cleanSourcePM.addPass(mlir::createElideConstantValuePass()); if (emissionTarget == EmitMLIR) cleanSourcePM.addPass(mlir::createElideConstGlobalValuePass()); if (emissionTarget == EmitONNXBasic || emissionTarget == EmitONNXIR || emissionTarget == EmitMLIR) { if (mlir::failed(cleanSourcePM.run(*module))) llvm::errs() << "Could not apply simplification passes.\n"; outputCode(module, outputBaseName, ".mlir"); printf("Constant-free MLIR Code written to: \n\t%s\n\n", (outputBaseName + ".mlir").c_str()); printf("Use:\n\t%s\nto continue lowering the code to other dialects.\n", (outputBaseName + ".onnx.mlir").c_str()); } } }