From 1777c07b1eb876f588f50a501ad47d954c0e9ecd Mon Sep 17 00:00:00 2001 From: Gheorghe-Teodor Bercea Date: Tue, 24 Mar 2020 13:48:54 -0400 Subject: [PATCH] [NFC] Reorganize main function. (#44) * Reorganize main function. * Follow review comments. * Use new file names. --- src/CMakeLists.txt | 5 ++- src/MainUtils.cpp | 99 ++++++++++++++++++++++++++++++++++++++++ src/MainUtils.hpp | 64 ++++++++++++++++++++++++++ src/main.cpp | 110 ++++----------------------------------------- 4 files changed, 175 insertions(+), 103 deletions(-) create mode 100644 src/MainUtils.cpp create mode 100644 src/MainUtils.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 697aaa4..f05d078 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -6,7 +6,10 @@ add_subdirectory(Tool) add_subdirectory(Builder) add_subdirectory(Runtime) -add_executable(onnx-mlir main.cpp) +add_executable(onnx-mlir + MainUtils.hpp + MainUtils.cpp + main.cpp) target_link_libraries(onnx-mlir ${MLIRLibs} OMBuilder diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp new file mode 100644 index 0000000..57bf4f0 --- /dev/null +++ b/src/MainUtils.cpp @@ -0,0 +1,99 @@ +//===--------------------------- main_utils.cpp ---------------------------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// Functions for adding passes and processing input files. +// +//===----------------------------------------------------------------------===// + +#include "src/MainUtils.hpp" + +using namespace std; +using namespace onnx_mlir; + +void LoadMLIR(string inputFilename, mlir::MLIRContext &context, + mlir::OwningModuleRef &module) { + // Handle '.mlir' input to the ONNX MLIR frontend. + // The mlir format indicates that one or more of the supported + // representations are used in the file. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return; + } + + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return; + } +} + +void EmitLLVMBitCode(const mlir::OwningModuleRef &module) { + error_code error; + llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error, + llvm::sys::fs::F_None); + llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module), + moduleBitcodeStream); + moduleBitcodeStream.flush(); +} + +void registerDialects() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); +} + +void addONNXToMLIRPasses(mlir::PassManager &pm) { + pm.addPass(mlir::createDecomposeONNXToONNXPass()); + pm.addPass(mlir::createShapeInferencePass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createShapeInferencePass()); + pm.addPass(mlir::createAttributePromotionPass()); +} + +void addONNXToKrnlPasses(mlir::PassManager &pm) { + pm.addPass(mlir::createLowerToKrnlPass()); + // An additional pass of canonicalization is helpful because lowering + // from ONNX dialect to Standard dialect exposes additional canonicalization + // oppertunities. + pm.addPass(mlir::createCanonicalizerPass()); +} + +void addKrnlToAffinePasses(mlir::PassManager &pm) { + pm.addPass(mlir::createLowerKrnlPass()); +} + +void addKrnlToLLVMPasses(mlir::PassManager &pm) { + pm.addPass(mlir::createLowerAffinePass()); + pm.addPass(mlir::createLowerToCFGPass()); + pm.addPass(mlir::createKrnlLowerToLLVMPass()); + pm.addPass(mlir::createCanonicalizerPass()); +} + +void processInputFile(string inputFilename, EmissionTargetType emissionTarget, + mlir::MLIRContext &context, mlir::OwningModuleRef &module) { + // Decide if the input file is an ONNX model or a model specified + // in MLIR. The extension of the file is the decider. + string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1); + bool inputIsONNX = (extension == "onnx"); + bool inputIsMLIR = (extension == "mlir"); + assert(inputIsONNX != inputIsMLIR && + "Either ONNX model or MLIR file needs to be provided."); + + + if (inputIsONNX) { + ImportFrontendModelFile(inputFilename, context, module); + } else { + LoadMLIR(inputFilename, context, module); + } +} diff --git a/src/MainUtils.hpp b/src/MainUtils.hpp new file mode 100644 index 0000000..eef6521 --- /dev/null +++ b/src/MainUtils.hpp @@ -0,0 +1,64 @@ +//===--------------------------- main_utils.hpp ---------------------------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// Functions for adding passes and processing input files. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileUtilities.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/Regex.h" +#include "llvm/Support/SourceMgr.h" + +#include "src/Builder/FrontendDialectTransformer.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Pass/Passes.hpp" + +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/InitAllDialects.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" +#include "mlir/Transforms/Passes.h" + +enum EmissionTargetType { + EmitONNXIR, + EmitMLIR, + EmitLLVMIR, + EmitLLVMBC, +}; + +void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context, + mlir::OwningModuleRef &module); + +void EmitLLVMBitCode(const mlir::OwningModuleRef &module); + +void registerDialects(); + +void addONNXToMLIRPasses(mlir::PassManager &pm); + +void addONNXToKrnlPasses(mlir::PassManager &pm); + +void addKrnlToAffinePasses(mlir::PassManager &pm); + +void addKrnlToLLVMPasses(mlir::PassManager &pm); + +void processInputFile(std::string inputFilename, + EmissionTargetType emissionTarget, mlir::MLIRContext &context, + mlir::OwningModuleRef &module); \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 33b6595..1dacbff 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -6,76 +6,13 @@ // //===----------------------------------------------------------------------===// -#include -#include - -#include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/FileUtilities.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/Regex.h" -#include "llvm/Support/SourceMgr.h" - -#include "src/Builder/FrontendDialectTransformer.hpp" -#include "src/Dialect/Krnl/KrnlOps.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" -#include "src/Pass/Passes.hpp" - -#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" -#include "mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/ExecutionEngine/OptUtils.h" -#include "mlir/InitAllDialects.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "mlir/Parser.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Target/LLVMIR.h" -#include "mlir/Transforms/Passes.h" - -void EmitLLVMBitCode(const mlir::OwningModuleRef &module); +#include "src/MainUtils.hpp" using namespace std; using namespace onnx_mlir; -void LoadMLIR(string inputFilename, mlir::MLIRContext &context, - mlir::OwningModuleRef &module) { - // Handle '.mlir' input to the ONNX MLIR frontend. - // The mlir format indicates that one or more of the supported - // representations are used in the file. - llvm::ErrorOr> fileOrErr = - llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); - if (std::error_code EC = fileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << EC.message() << "\n"; - return; - } - - // Parse the input mlir. - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - module = mlir::parseSourceFile(sourceMgr, &context); - if (!module) { - llvm::errs() << "Error can't load file " << inputFilename << "\n"; - return; - } -} - -void EmitLLVMBitCode(const mlir::OwningModuleRef &module) { - error_code error; - llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error, - llvm::sys::fs::F_None); - llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module), - moduleBitcodeStream); - moduleBitcodeStream.flush(); -} - int main(int argc, char *argv[]) { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); + registerDialects(); llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options", "These are frontend options."); @@ -83,12 +20,6 @@ int main(int argc, char *argv[]) { llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-"), llvm::cl::cat(OnnxMlirOptions)); - enum EmissionTargetType { - EmitONNXIR, - EmitMLIR, - EmitLLVMIR, - EmitLLVMBC, - }; llvm::cl::opt emissionTarget( llvm::cl::desc("Choose target to emit:"), llvm::cl::values( @@ -105,49 +36,24 @@ int main(int argc, char *argv[]) { llvm::cl::ParseCommandLineOptions(argc, argv, "ONNX MLIR modular optimizer driver\n"); - // Decide if the input file is an ONNX model or a model specified - // in MLIR. The extension of the file is the decider. - string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1); - bool inputIsONNX = (extension == "onnx"); - bool inputIsMLIR = (extension == "mlir"); - assert(inputIsONNX != inputIsMLIR && - "Either ONNX model or MLIR file needs to be provided."); - mlir::MLIRContext context; mlir::OwningModuleRef module; - if (inputIsONNX) { - ImportFrontendModelFile(inputFilename, context, module); - } else { - LoadMLIR(inputFilename, context, module); - } + processInputFile(inputFilename, emissionTarget, context, module); mlir::PassManager pm(&context); - pm.addPass(mlir::createDecomposeONNXToONNXPass()); - pm.addPass(mlir::createShapeInferencePass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createShapeInferencePass()); - pm.addPass(mlir::createAttributePromotionPass()); + addONNXToMLIRPasses(pm); if (emissionTarget >= EmitMLIR) { - pm.addPass(mlir::createLowerToKrnlPass()); - // An additional pass of canonicalization is helpful because lowering - // from ONNX dialect to Standard dialect exposes additional canonicalization - // oppertunities. - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createLowerKrnlPass()); + addONNXToKrnlPasses(pm); + addKrnlToAffinePasses(pm); } - if (emissionTarget >= EmitLLVMIR) { - pm.addPass(mlir::createLowerAffinePass()); - pm.addPass(mlir::createLowerToCFGPass()); - pm.addPass(mlir::createKrnlLowerToLLVMPass()); - pm.addPass(mlir::createCanonicalizerPass()); - } + if (emissionTarget >= EmitLLVMIR) + addKrnlToLLVMPasses(pm); if (mlir::failed(pm.run(*module))) return 4; - if (emissionTarget == EmitLLVMBC) { // Write LLVM bitcode to disk. EmitLLVMBitCode(module);