[NFC] Reorganize main function. (#44)
* Reorganize main function. * Follow review comments. * Use new file names.
This commit is contained in:
parent
ddff0f1256
commit
1777c07b1e
|
@ -6,7 +6,10 @@ add_subdirectory(Tool)
|
||||||
add_subdirectory(Builder)
|
add_subdirectory(Builder)
|
||||||
add_subdirectory(Runtime)
|
add_subdirectory(Runtime)
|
||||||
|
|
||||||
add_executable(onnx-mlir main.cpp)
|
add_executable(onnx-mlir
|
||||||
|
MainUtils.hpp
|
||||||
|
MainUtils.cpp
|
||||||
|
main.cpp)
|
||||||
target_link_libraries(onnx-mlir
|
target_link_libraries(onnx-mlir
|
||||||
${MLIRLibs}
|
${MLIRLibs}
|
||||||
OMBuilder
|
OMBuilder
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
//===--------------------------- main_utils.cpp ---------------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// Functions for adding passes and processing input files.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "src/MainUtils.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace onnx_mlir;
|
||||||
|
|
||||||
|
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
||||||
|
mlir::OwningModuleRef &module) {
|
||||||
|
// Handle '.mlir' input to the ONNX MLIR frontend.
|
||||||
|
// The mlir format indicates that one or more of the supported
|
||||||
|
// representations are used in the file.
|
||||||
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
|
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||||
|
if (std::error_code EC = fileOrErr.getError()) {
|
||||||
|
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the input mlir.
|
||||||
|
llvm::SourceMgr sourceMgr;
|
||||||
|
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||||
|
module = mlir::parseSourceFile(sourceMgr, &context);
|
||||||
|
if (!module) {
|
||||||
|
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
|
||||||
|
error_code error;
|
||||||
|
llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error,
|
||||||
|
llvm::sys::fs::F_None);
|
||||||
|
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
|
||||||
|
moduleBitcodeStream);
|
||||||
|
moduleBitcodeStream.flush();
|
||||||
|
}
|
||||||
|
|
||||||
|
void registerDialects() {
|
||||||
|
mlir::registerDialect<mlir::AffineOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
|
||||||
|
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void addONNXToMLIRPasses(mlir::PassManager &pm) {
|
||||||
|
pm.addPass(mlir::createDecomposeONNXToONNXPass());
|
||||||
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
|
pm.addPass(mlir::createAttributePromotionPass());
|
||||||
|
}
|
||||||
|
|
||||||
|
void addONNXToKrnlPasses(mlir::PassManager &pm) {
|
||||||
|
pm.addPass(mlir::createLowerToKrnlPass());
|
||||||
|
// An additional pass of canonicalization is helpful because lowering
|
||||||
|
// from ONNX dialect to Standard dialect exposes additional canonicalization
|
||||||
|
// oppertunities.
|
||||||
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
|
}
|
||||||
|
|
||||||
|
void addKrnlToAffinePasses(mlir::PassManager &pm) {
|
||||||
|
pm.addPass(mlir::createLowerKrnlPass());
|
||||||
|
}
|
||||||
|
|
||||||
|
void addKrnlToLLVMPasses(mlir::PassManager &pm) {
|
||||||
|
pm.addPass(mlir::createLowerAffinePass());
|
||||||
|
pm.addPass(mlir::createLowerToCFGPass());
|
||||||
|
pm.addPass(mlir::createKrnlLowerToLLVMPass());
|
||||||
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
|
}
|
||||||
|
|
||||||
|
void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
|
||||||
|
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||||||
|
// Decide if the input file is an ONNX model or a model specified
|
||||||
|
// in MLIR. The extension of the file is the decider.
|
||||||
|
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
|
||||||
|
bool inputIsONNX = (extension == "onnx");
|
||||||
|
bool inputIsMLIR = (extension == "mlir");
|
||||||
|
assert(inputIsONNX != inputIsMLIR &&
|
||||||
|
"Either ONNX model or MLIR file needs to be provided.");
|
||||||
|
|
||||||
|
|
||||||
|
if (inputIsONNX) {
|
||||||
|
ImportFrontendModelFile(inputFilename, context, module);
|
||||||
|
} else {
|
||||||
|
LoadMLIR(inputFilename, context, module);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,64 @@
|
||||||
|
//===--------------------------- main_utils.hpp ---------------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// Functions for adding passes and processing input files.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "llvm/Bitcode/BitcodeWriter.h"
|
||||||
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
#include "llvm/Support/FileUtilities.h"
|
||||||
|
#include "llvm/Support/InitLLVM.h"
|
||||||
|
#include "llvm/Support/Regex.h"
|
||||||
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
|
||||||
|
#include "src/Builder/FrontendDialectTransformer.hpp"
|
||||||
|
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
#include "src/Pass/Passes.hpp"
|
||||||
|
|
||||||
|
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
||||||
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||||
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||||
|
#include "mlir/InitAllDialects.h"
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/IR/Module.h"
|
||||||
|
#include "mlir/Parser.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassManager.h"
|
||||||
|
#include "mlir/Target/LLVMIR.h"
|
||||||
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
|
enum EmissionTargetType {
|
||||||
|
EmitONNXIR,
|
||||||
|
EmitMLIR,
|
||||||
|
EmitLLVMIR,
|
||||||
|
EmitLLVMBC,
|
||||||
|
};
|
||||||
|
|
||||||
|
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
|
||||||
|
mlir::OwningModuleRef &module);
|
||||||
|
|
||||||
|
void EmitLLVMBitCode(const mlir::OwningModuleRef &module);
|
||||||
|
|
||||||
|
void registerDialects();
|
||||||
|
|
||||||
|
void addONNXToMLIRPasses(mlir::PassManager &pm);
|
||||||
|
|
||||||
|
void addONNXToKrnlPasses(mlir::PassManager &pm);
|
||||||
|
|
||||||
|
void addKrnlToAffinePasses(mlir::PassManager &pm);
|
||||||
|
|
||||||
|
void addKrnlToLLVMPasses(mlir::PassManager &pm);
|
||||||
|
|
||||||
|
void processInputFile(std::string inputFilename,
|
||||||
|
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
|
||||||
|
mlir::OwningModuleRef &module);
|
110
src/main.cpp
110
src/main.cpp
|
@ -6,76 +6,13 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include <cmath>
|
#include "src/MainUtils.hpp"
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "llvm/Bitcode/BitcodeWriter.h"
|
|
||||||
#include "llvm/Support/CommandLine.h"
|
|
||||||
#include "llvm/Support/FileUtilities.h"
|
|
||||||
#include "llvm/Support/InitLLVM.h"
|
|
||||||
#include "llvm/Support/Regex.h"
|
|
||||||
#include "llvm/Support/SourceMgr.h"
|
|
||||||
|
|
||||||
#include "src/Builder/FrontendDialectTransformer.hpp"
|
|
||||||
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
#include "src/Pass/Passes.hpp"
|
|
||||||
|
|
||||||
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
|
||||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
|
||||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
|
||||||
#include "mlir/InitAllDialects.h"
|
|
||||||
#include "mlir/IR/MLIRContext.h"
|
|
||||||
#include "mlir/IR/Module.h"
|
|
||||||
#include "mlir/Parser.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
#include "mlir/Pass/PassManager.h"
|
|
||||||
#include "mlir/Target/LLVMIR.h"
|
|
||||||
#include "mlir/Transforms/Passes.h"
|
|
||||||
|
|
||||||
void EmitLLVMBitCode(const mlir::OwningModuleRef &module);
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
|
|
||||||
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
|
||||||
mlir::OwningModuleRef &module) {
|
|
||||||
// Handle '.mlir' input to the ONNX MLIR frontend.
|
|
||||||
// The mlir format indicates that one or more of the supported
|
|
||||||
// representations are used in the file.
|
|
||||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
|
||||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
|
||||||
if (std::error_code EC = fileOrErr.getError()) {
|
|
||||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the input mlir.
|
|
||||||
llvm::SourceMgr sourceMgr;
|
|
||||||
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
|
||||||
module = mlir::parseSourceFile(sourceMgr, &context);
|
|
||||||
if (!module) {
|
|
||||||
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
|
|
||||||
error_code error;
|
|
||||||
llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error,
|
|
||||||
llvm::sys::fs::F_None);
|
|
||||||
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
|
|
||||||
moduleBitcodeStream);
|
|
||||||
moduleBitcodeStream.flush();
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
mlir::registerDialect<mlir::AffineOpsDialect>();
|
registerDialects();
|
||||||
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
|
|
||||||
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
|
|
||||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
|
||||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
|
||||||
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
|
||||||
|
|
||||||
llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options",
|
llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options",
|
||||||
"These are frontend options.");
|
"These are frontend options.");
|
||||||
|
@ -83,12 +20,6 @@ int main(int argc, char *argv[]) {
|
||||||
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
|
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
enum EmissionTargetType {
|
|
||||||
EmitONNXIR,
|
|
||||||
EmitMLIR,
|
|
||||||
EmitLLVMIR,
|
|
||||||
EmitLLVMBC,
|
|
||||||
};
|
|
||||||
llvm::cl::opt<EmissionTargetType> emissionTarget(
|
llvm::cl::opt<EmissionTargetType> emissionTarget(
|
||||||
llvm::cl::desc("Choose target to emit:"),
|
llvm::cl::desc("Choose target to emit:"),
|
||||||
llvm::cl::values(
|
llvm::cl::values(
|
||||||
|
@ -105,49 +36,24 @@ int main(int argc, char *argv[]) {
|
||||||
llvm::cl::ParseCommandLineOptions(argc, argv,
|
llvm::cl::ParseCommandLineOptions(argc, argv,
|
||||||
"ONNX MLIR modular optimizer driver\n");
|
"ONNX MLIR modular optimizer driver\n");
|
||||||
|
|
||||||
// Decide if the input file is an ONNX model or a model specified
|
|
||||||
// in MLIR. The extension of the file is the decider.
|
|
||||||
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
|
|
||||||
bool inputIsONNX = (extension == "onnx");
|
|
||||||
bool inputIsMLIR = (extension == "mlir");
|
|
||||||
assert(inputIsONNX != inputIsMLIR &&
|
|
||||||
"Either ONNX model or MLIR file needs to be provided.");
|
|
||||||
|
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
mlir::OwningModuleRef module;
|
mlir::OwningModuleRef module;
|
||||||
if (inputIsONNX) {
|
processInputFile(inputFilename, emissionTarget, context, module);
|
||||||
ImportFrontendModelFile(inputFilename, context, module);
|
|
||||||
} else {
|
|
||||||
LoadMLIR(inputFilename, context, module);
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::PassManager pm(&context);
|
mlir::PassManager pm(&context);
|
||||||
pm.addPass(mlir::createDecomposeONNXToONNXPass());
|
addONNXToMLIRPasses(pm);
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
|
||||||
pm.addPass(mlir::createAttributePromotionPass());
|
|
||||||
|
|
||||||
if (emissionTarget >= EmitMLIR) {
|
if (emissionTarget >= EmitMLIR) {
|
||||||
pm.addPass(mlir::createLowerToKrnlPass());
|
addONNXToKrnlPasses(pm);
|
||||||
// An additional pass of canonicalization is helpful because lowering
|
addKrnlToAffinePasses(pm);
|
||||||
// from ONNX dialect to Standard dialect exposes additional canonicalization
|
|
||||||
// oppertunities.
|
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
|
||||||
pm.addPass(mlir::createLowerKrnlPass());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (emissionTarget >= EmitLLVMIR) {
|
if (emissionTarget >= EmitLLVMIR)
|
||||||
pm.addPass(mlir::createLowerAffinePass());
|
addKrnlToLLVMPasses(pm);
|
||||||
pm.addPass(mlir::createLowerToCFGPass());
|
|
||||||
pm.addPass(mlir::createKrnlLowerToLLVMPass());
|
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mlir::failed(pm.run(*module)))
|
if (mlir::failed(pm.run(*module)))
|
||||||
return 4;
|
return 4;
|
||||||
|
|
||||||
|
|
||||||
if (emissionTarget == EmitLLVMBC) {
|
if (emissionTarget == EmitLLVMBC) {
|
||||||
// Write LLVM bitcode to disk.
|
// Write LLVM bitcode to disk.
|
||||||
EmitLLVMBitCode(module);
|
EmitLLVMBitCode(module);
|
||||||
|
|
Loading…
Reference in New Issue