181 lines
6.5 KiB
C++
181 lines
6.5 KiB
C++
//===--------------------------- main_utils.cpp ---------------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// Functions for adding passes and processing input files.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "src/MainUtils.hpp"
|
|
#include <fcntl.h>
|
|
#include <stdio.h>
|
|
|
|
#ifdef _WIN32
|
|
#include <io.h>
|
|
#else
|
|
#include <unistd.h>
|
|
#endif
|
|
|
|
using namespace std;
|
|
using namespace onnx_mlir;
|
|
|
|
// Loads a textual MLIR file into `module`.
//
// The file contents are obtained through llvm::MemoryBuffer::getFileOrSTDIN
// and parsed within `context`. On any failure a message is written to
// llvm::errs() and the function returns; after a failed parse `module`
// holds the (null) parse result.
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
    mlir::OwningModuleRef &module) {
  // The '.mlir' format may mix any of the representations supported by the
  // ONNX MLIR frontend.
  auto bufferOrError = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
  if (!bufferOrError) {
    llvm::errs() << "Could not open input file: "
                 << bufferOrError.getError().message() << "\n";
    return;
  }

  // Transfer ownership of the buffer to a source manager and parse it.
  llvm::SourceMgr sources;
  sources.AddNewSourceBuffer(std::move(*bufferOrError), llvm::SMLoc());
  module = mlir::parseSourceFile(sources, &context);

  if (!module)
    llvm::errs() << "Error can't load file " << inputFilename << "\n";
}
|
|
|
|
void EmitLLVMBitCode(
|
|
const mlir::OwningModuleRef &module, string outputFilename) {
|
|
error_code error;
|
|
llvm::raw_fd_ostream moduleBitcodeStream(
|
|
outputFilename, error, llvm::sys::fs::F_None);
|
|
llvm::WriteBitcodeToFile(
|
|
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
|
|
moduleBitcodeStream.flush();
|
|
}
|
|
|
|
void registerDialects() {
|
|
mlir::registerDialect<mlir::AffineDialect>();
|
|
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
|
|
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
|
|
mlir::registerDialect<mlir::StandardOpsDialect>();
|
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
|
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
|
}
|
|
|
|
// Appends the ONNX-level passes to `pm`, in execution order.
void addONNXToMLIRPasses(mlir::PassManager &pm) {
  // Decompose ONNX operations into other ONNX operations.
  pm.addPass(mlir::createDecomposeONNXToONNXPass());

  // Shape inference is scheduled both before and after canonicalization
  // (presumably because canonicalization can expose new shape information).
  pm.addPass(mlir::createShapeInferencePass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createShapeInferencePass());

  // Attribute promotion is the last step of this stage.
  pm.addPass(mlir::createAttributePromotionPass());
}
|
|
|
|
// Appends to `pm` the passes that lower to the Krnl dialect.
void addONNXToKrnlPasses(mlir::PassManager &pm) {
  pm.addPass(mlir::createLowerToKrnlPass());

  // Canonicalize once more: lowering from the ONNX dialect to the Standard
  // dialect exposes additional canonicalization opportunities.
  pm.addPass(mlir::createCanonicalizerPass());
}
|
|
|
|
// Appends to `pm` the single pass of the Krnl-to-Affine stage, which lowers
// the Krnl dialect.
void addKrnlToAffinePasses(mlir::PassManager &pm) {
  pm.addPass(mlir::createLowerKrnlPass());
}
|
|
|
|
// Appends to `pm` the passes of the final lowering stage, in execution order.
void addKrnlToLLVMPasses(mlir::PassManager &pm) {
  // Lower Affine constructs, then lower to a control-flow graph.
  pm.addPass(mlir::createLowerAffinePass());
  pm.addPass(mlir::createLowerToCFGPass());

  // Lower the remaining Krnl operations to the LLVM dialect.
  pm.addPass(mlir::createKrnlLowerToLLVMPass());

  // Final canonicalization of the lowered code.
  pm.addPass(mlir::createCanonicalizerPass());
}
|
|
|
|
// Loads the model stored in `inputFilename` into `module`.
//
// The file extension decides how the input is read:
//   * "onnx" - the file is imported as an ONNX model;
//   * "mlir" - the file is parsed as MLIR.
// The name "-" is also routed to the MLIR loader, which obtains its input
// through llvm::MemoryBuffer::getFileOrSTDIN.
//
// Any other input is rejected with a message on llvm::errs() and `module` is
// left untouched. This used to be enforced with an assert only, so builds
// with assertions disabled silently handed arbitrary files to the MLIR
// parser.
//
// `emissionTarget` is currently not used by this function.
void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
    mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
  // When the name contains no '.', find_last_of returns npos; npos + 1 wraps
  // to 0, so the whole name is taken as the "extension" and matches neither
  // supported format.
  const string extension =
      inputFilename.substr(inputFilename.find_last_of('.') + 1);
  const bool inputIsONNX = (extension == "onnx");
  const bool inputIsMLIR = (extension == "mlir") || (inputFilename == "-");

  if (inputIsONNX) {
    ImportFrontendModelFile(inputFilename, context, module);
    return;
  }
  if (inputIsMLIR) {
    LoadMLIR(inputFilename, context, module);
    return;
  }

  llvm::errs() << "Unsupported input file '" << inputFilename
               << "': either an ONNX model (.onnx) or an MLIR file (.mlir) "
                  "needs to be provided.\n";
}
|
|
|
|
// Writes the textual form of `module` to the file `filename + extension`.
//
// The module is printed directly into a file stream. Earlier this was done by
// temporarily redirecting stderr and calling dump(): in a forked child on
// POSIX and in-process on Windows. That approach had several problems:
//   * the parent never waited for the child, so callers could report the file
//     as written before it existed, and the child was never reaped;
//   * a failing fork() or freopen() went unnoticed;
//   * the child's exit() flushed stdio buffers inherited from the parent,
//     which could duplicate already-buffered output;
//   * the Windows branch leaked the duplicated stderr descriptor.
// Printing to a dedicated stream leaves the process-wide stderr alone and
// behaves the same on every platform.
//
// If the file cannot be opened, a message is written to llvm::errs() and
// nothing is emitted.
void outputCode(
    mlir::OwningModuleRef &module, string filename, string extension) {
  const string outputFilename = filename + extension;

  std::error_code error;
  llvm::raw_fd_ostream outputStream(
      outputFilename, error, llvm::sys::fs::F_None);
  if (error) {
    llvm::errs() << "Could not open output file " << outputFilename << ": "
                 << error.message() << "\n";
    return;
  }

  module->print(outputStream);
  // Terminate the output with a newline so the file ends on a complete line
  // (NOTE(review): dump() is expected to have done the same - confirm).
  outputStream << "\n";
  outputStream.flush();
}
|
|
|
|
// Emits the output file(s) for `module`, named after `outputBaseName`,
// according to `emissionTarget`.
//
// For EmitLLVMBC a single bitcode file, <name>.bc, is written.
//
// For every other target the code is written with all constant values
// included to:
//
//     <name>.onnx.mlir
//
// This version is meant to be passed back to onnx-mlir for further
// processing. Embedded constant values make the code hard to read, so for
// EmitONNXBasic, EmitONNXIR and EmitMLIR a second version with the constants
// elided, meant for inspection by users, is written to:
//
//     <name>.mlir
//
// In the LLVM Dialect IR the constant values are grouped outside the function
// code at the beginning of the file, so eliding them is not strictly
// required. Elision is not necessary for the .bc file either.
void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
    mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
  if (emissionTarget == EmitLLVMBC) {
    // Write LLVM bitcode to disk; nothing else is produced for this target.
    const string bitcodeFilename = outputBaseName + ".bc";
    EmitLLVMBitCode(module, bitcodeFilename);
    printf("LLVM bitcode written to %s\n", bitcodeFilename.c_str());
    return;
  }

  // First emit the version that still contains every constant.
  const string fullCodeFilename = outputBaseName + ".onnx.mlir";
  outputCode(module, outputBaseName, ".onnx.mlir");
  printf("Full MLIR code written to: \n\t%s\n\n", fullCodeFilename.c_str());

  // Select the clean-up pass that matches the emission target.
  const bool elideONNXConstants =
      (emissionTarget == EmitONNXIR) || (emissionTarget == EmitONNXBasic);
  const bool elideGlobalConstants = (emissionTarget == EmitMLIR);

  mlir::PassManager cleanSourcePM(&context);
  if (elideONNXConstants)
    cleanSourcePM.addPass(mlir::createElideConstantValuePass());
  if (elideGlobalConstants)
    cleanSourcePM.addPass(mlir::createElideConstGlobalValuePass());

  // Targets without a clean-up pass get no constant-free version.
  if (!elideONNXConstants && !elideGlobalConstants)
    return;

  if (mlir::failed(cleanSourcePM.run(*module)))
    llvm::errs() << "Could not apply simplification passes.\n";

  const string cleanCodeFilename = outputBaseName + ".mlir";
  outputCode(module, outputBaseName, ".mlir");
  printf("Constant-free MLIR Code written to: \n\t%s\n\n",
      cleanCodeFilename.c_str());

  printf("Use:\n\t%s\nto continue lowering the code to other dialects.\n",
      fullCodeFilename.c_str());
}