[NFC] Reorganize main function. (#44)

* Reorganize main function.

* Follow review comments.

* Use new file names.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-03-24 13:48:54 -04:00 committed by GitHub
parent ddff0f1256
commit 1777c07b1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 175 additions and 103 deletions

View File

@ -6,7 +6,10 @@ add_subdirectory(Tool)
add_subdirectory(Builder)
add_subdirectory(Runtime)
add_executable(onnx-mlir main.cpp)
add_executable(onnx-mlir
MainUtils.hpp
MainUtils.cpp
main.cpp)
target_link_libraries(onnx-mlir
${MLIRLibs}
OMBuilder

99
src/MainUtils.cpp Normal file
View File

@ -0,0 +1,99 @@
//===--------------------------- main_utils.cpp ---------------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// Functions for adding passes and processing input files.
//
//===----------------------------------------------------------------------===//
#include "src/MainUtils.hpp"
using namespace std;
using namespace onnx_mlir;
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module) {
// Handle '.mlir' input to the ONNX MLIR frontend.
// The mlir format indicates that one or more of the supported
// representations are used in the file.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
return;
}
// Parse the input mlir.
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
module = mlir::parseSourceFile(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return;
}
}
void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
error_code error;
llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error,
llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
moduleBitcodeStream);
moduleBitcodeStream.flush();
}
void registerDialects() {
mlir::registerDialect<mlir::AffineOpsDialect>();
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>();
}
void addONNXToMLIRPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createDecomposeONNXToONNXPass());
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createAttributePromotionPass());
}
void addONNXToKrnlPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createLowerToKrnlPass());
// An additional pass of canonicalization is helpful because lowering
// from ONNX dialect to Standard dialect exposes additional canonicalization
// oppertunities.
pm.addPass(mlir::createCanonicalizerPass());
}
void addKrnlToAffinePasses(mlir::PassManager &pm) {
pm.addPass(mlir::createLowerKrnlPass());
}
void addKrnlToLLVMPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createLowerToCFGPass());
pm.addPass(mlir::createKrnlLowerToLLVMPass());
pm.addPass(mlir::createCanonicalizerPass());
}
void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
// Decide if the input file is an ONNX model or a model specified
// in MLIR. The extension of the file is the decider.
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
bool inputIsONNX = (extension == "onnx");
bool inputIsMLIR = (extension == "mlir");
assert(inputIsONNX != inputIsMLIR &&
"Either ONNX model or MLIR file needs to be provided.");
if (inputIsONNX) {
ImportFrontendModelFile(inputFilename, context, module);
} else {
LoadMLIR(inputFilename, context, module);
}
}

64
src/MainUtils.hpp Normal file
View File

@ -0,0 +1,64 @@
//===--------------------------- main_utils.hpp ---------------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// Functions for adding passes and processing input files.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <cmath>
#include <iostream>
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "src/Builder/FrontendDialectTransformer.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/InitAllDialects.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/Passes.h"
enum EmissionTargetType {
EmitONNXIR,
EmitMLIR,
EmitLLVMIR,
EmitLLVMBC,
};
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module);
void EmitLLVMBitCode(const mlir::OwningModuleRef &module);
void registerDialects();
void addONNXToMLIRPasses(mlir::PassManager &pm);
void addONNXToKrnlPasses(mlir::PassManager &pm);
void addKrnlToAffinePasses(mlir::PassManager &pm);
void addKrnlToLLVMPasses(mlir::PassManager &pm);
void processInputFile(std::string inputFilename,
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
mlir::OwningModuleRef &module);

View File

@ -6,76 +6,13 @@
//
//===----------------------------------------------------------------------===//
#include <cmath>
#include <iostream>
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "src/Builder/FrontendDialectTransformer.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/InitAllDialects.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/Passes.h"
void EmitLLVMBitCode(const mlir::OwningModuleRef &module);
#include "src/MainUtils.hpp"
using namespace std;
using namespace onnx_mlir;
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module) {
// Handle '.mlir' input to the ONNX MLIR frontend.
// The mlir format indicates that one or more of the supported
// representations are used in the file.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
return;
}
// Parse the input mlir.
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
module = mlir::parseSourceFile(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return;
}
}
void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
error_code error;
llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error,
llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
moduleBitcodeStream);
moduleBitcodeStream.flush();
}
int main(int argc, char *argv[]) {
mlir::registerDialect<mlir::AffineOpsDialect>();
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>();
registerDialects();
llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options",
"These are frontend options.");
@ -83,12 +20,6 @@ int main(int argc, char *argv[]) {
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::cat(OnnxMlirOptions));
enum EmissionTargetType {
EmitONNXIR,
EmitMLIR,
EmitLLVMIR,
EmitLLVMBC,
};
llvm::cl::opt<EmissionTargetType> emissionTarget(
llvm::cl::desc("Choose target to emit:"),
llvm::cl::values(
@ -105,49 +36,24 @@ int main(int argc, char *argv[]) {
llvm::cl::ParseCommandLineOptions(argc, argv,
"ONNX MLIR modular optimizer driver\n");
// Decide if the input file is an ONNX model or a model specified
// in MLIR. The extension of the file is the decider.
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
bool inputIsONNX = (extension == "onnx");
bool inputIsMLIR = (extension == "mlir");
assert(inputIsONNX != inputIsMLIR &&
"Either ONNX model or MLIR file needs to be provided.");
mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (inputIsONNX) {
ImportFrontendModelFile(inputFilename, context, module);
} else {
LoadMLIR(inputFilename, context, module);
}
processInputFile(inputFilename, emissionTarget, context, module);
mlir::PassManager pm(&context);
pm.addPass(mlir::createDecomposeONNXToONNXPass());
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createAttributePromotionPass());
addONNXToMLIRPasses(pm);
if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass());
// An additional pass of canonicalization is helpful because lowering
// from ONNX dialect to Standard dialect exposes additional canonicalization
// oppertunities.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createLowerKrnlPass());
addONNXToKrnlPasses(pm);
addKrnlToAffinePasses(pm);
}
if (emissionTarget >= EmitLLVMIR) {
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createLowerToCFGPass());
pm.addPass(mlir::createKrnlLowerToLLVMPass());
pm.addPass(mlir::createCanonicalizerPass());
}
if (emissionTarget >= EmitLLVMIR)
addKrnlToLLVMPasses(pm);
if (mlir::failed(pm.run(*module)))
return 4;
if (emissionTarget == EmitLLVMBC) {
// Write LLVM bitcode to disk.
EmitLLVMBitCode(module);