71 lines
2.3 KiB
C++
71 lines
2.3 KiB
C++
//===--------------------------- main.cpp ---------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// Command-line driver for onnx-mlir: ingests an ONNX model and lowers it to
// the emission target selected on the command line.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "src/MainUtils.hpp"
|
|
|
|
using namespace std;
|
|
using namespace onnx_mlir;
|
|
|
|
int main(int argc, char *argv[]) {
|
|
registerDialects();
|
|
|
|
llvm::cl::OptionCategory OnnxMlirOptions(
|
|
"ONNX MLIR Options", "These are frontend options.");
|
|
llvm::cl::opt<string> inputFilename(llvm::cl::Positional,
|
|
llvm::cl::desc("<input file>"), llvm::cl::init("-"),
|
|
llvm::cl::cat(OnnxMlirOptions));
|
|
|
|
llvm::cl::opt<EmissionTargetType> emissionTarget(
|
|
llvm::cl::desc("Choose target to emit:"),
|
|
llvm::cl::values(
|
|
clEnumVal(EmitONNXBasic,
|
|
"Ingest ONNX and emit the basic ONNX operations without"
|
|
"inferred shapes."),
|
|
clEnumVal(
|
|
EmitONNXIR, "Ingest ONNX and emit corresponding ONNX dialect."),
|
|
clEnumVal(
|
|
EmitMLIR, "Lower model to MLIR built-in transformation dialect."),
|
|
clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."),
|
|
clEnumVal(EmitLib, "Lower model to LLVM IR, emit (to file) "
|
|
"LLVM bitcode for model, compile and link it to a "
|
|
"shared library.")),
|
|
llvm::cl::init(EmitLib), llvm::cl::cat(OnnxMlirOptions));
|
|
|
|
llvm::cl::HideUnrelatedOptions(OnnxMlirOptions);
|
|
llvm::cl::ParseCommandLineOptions(
|
|
argc, argv, "ONNX MLIR modular optimizer driver\n");
|
|
|
|
mlir::MLIRContext context;
|
|
mlir::OwningModuleRef module;
|
|
processInputFile(inputFilename, emissionTarget, context, module);
|
|
|
|
// Input file base name.
|
|
string outputBaseName =
|
|
inputFilename.substr(0, inputFilename.find_last_of("."));
|
|
|
|
mlir::PassManager pm(&context);
|
|
if (emissionTarget >= EmitONNXIR) {
|
|
addONNXToMLIRPasses(pm);
|
|
}
|
|
|
|
if (emissionTarget >= EmitMLIR) {
|
|
addONNXToKrnlPasses(pm);
|
|
addKrnlToAffinePasses(pm);
|
|
}
|
|
|
|
if (emissionTarget >= EmitLLVMIR)
|
|
addKrnlToLLVMPasses(pm);
|
|
|
|
if (mlir::failed(pm.run(*module)))
|
|
return 4;
|
|
|
|
emitOutputFiles(outputBaseName, emissionTarget, context, module);
|
|
|
|
return 0;
|
|
}
|