onnx-mlir/src/MainUtils.cpp

520 lines
18 KiB
C++

//===--------------------------- MainUtils.cpp ---------------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// Functions for adding passes and processing input files.
//
//===----------------------------------------------------------------------===//
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <fcntl.h>
#include <regex>
#include <string>
#include <vector>
#include <llvm/Support/ErrorHandling.h>
#include <llvm/Support/FileSystem.h>
#include <llvm/Support/Program.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/SymbolTable.h>
#include "src/ExternalUtil.hpp"
#include "src/MainUtils.hpp"
#ifdef _WIN32
#include <io.h>
#else
#include <sys/wait.h>
#include <unistd.h>
#endif
using namespace std;
using namespace onnx_mlir;
namespace {
// Look up the environment variable `name`. Yields its value when the variable
// is set, and llvm::None when it is not.
llvm::Optional<std::string> getEnvVar(std::string name) {
  const char *rawValue = std::getenv(name.c_str());
  if (rawValue == nullptr)
    return llvm::None;
  return std::string(rawValue);
}
// Runtime directory contains all the libraries, jars, etc. that are
// necessary for running onnx-mlir. It's resolved in the following order:
//
// - if ONNX_MLIR_RUNTIME_DIR is set and names an existing path, use it,
//   otherwise
// - get path from where onnx-mlir is run, if it's of the form
//   /foo/bar/bin/onnx-mlir, the runtime directory is /foo/bar/lib (note that
//   when onnx-mlir is installed system wide, which is typically
//   /usr/local/bin, this will correctly resolve to /usr/local/lib), but some
//   systems still have lib64 so we check that first. If neither exists, then
// - use CMAKE_INSTALL_PREFIX/lib64 if it exists, otherwise
// - use CMAKE_INSTALL_PREFIX/lib, which is typically /usr/local/lib
string getRuntimeDir() {
  const llvm::Optional<std::string> envDir =
      getEnvVar("ONNX_MLIR_RUNTIME_DIR");
  if (envDir && llvm::sys::fs::exists(envDir.getValue()))
    return envDir.getValue();

  // Joins `base` and `leaf` with the platform path separator.
  auto joinPath = [](llvm::StringRef base, llvm::StringRef leaf) {
    llvm::SmallString<64> joined(base);
    llvm::sys::path::append(joined, leaf);
    return llvm::StringRef(joined).str();
  };

  const string execDir = llvm::sys::path::parent_path(kExecPath).str();
  // Compare the complete directory name. path::stem() would drop anything
  // after the last '.', so a directory such as "bin.old" would be mistaken
  // for "bin".
  if (llvm::sys::path::filename(execDir) == "bin") {
    const string prefix = llvm::sys::path::parent_path(execDir).str();
    const string execLib64 = joinPath(prefix, "lib64");
    if (llvm::sys::fs::exists(execLib64))
      return execLib64;
    const string execLib = joinPath(prefix, "lib");
    if (llvm::sys::fs::exists(execLib))
      return execLib;
  }

  // Fall back to the install prefix; the plain "lib" directory is returned
  // without an existence check as the last resort.
  const string instLib64 = joinPath(kInstPath, "lib64");
  if (llvm::sys::fs::exists(instLib64))
    return instLib64;
  return joinPath(kInstPath, "lib");
}
// Helper struct to make command construction and execution easy & readable.
//
// A Command is created from the path of an executable; the argument list
// always starts with the executable's file name (argv[0]). The append*
// methods return *this so that calls can be chained, and exec() runs the
// command synchronously.
struct Command {
  // Path of the executable to run.
  std::string _path;
  // Full argument vector, including the program name at index 0.
  std::vector<std::string> _args;

  Command(std::string exePath)
      : _path(std::move(exePath)),
        _args({llvm::sys::path::filename(_path).str()}) {}

  // Append a single string argument.
  Command &appendStr(const std::string &arg) {
    _args.emplace_back(arg);
    return *this;
  }

  // Append a single optional string argument; nothing is appended when the
  // optional is empty.
  Command &appendStrOpt(const llvm::Optional<std::string> &arg) {
    if (arg.hasValue())
      _args.emplace_back(arg.getValue());
    return *this;
  }

  // Append a list of string arguments.
  Command &appendList(const std::vector<std::string> &args) {
    _args.insert(_args.end(), args.begin(), args.end());
    return *this;
  }

  // Reset arguments, keeping only the program name at index 0.
  Command &resetArgs() {
    _args.erase(_args.begin() + 1, _args.end());
    return *this;
  }

  // Execute command and wait for it to finish. When the VERBOSE environment
  // variable parses as a true boolean, the command line is echoed to stdout
  // first. A non-zero result from ExecuteAndWait is treated as fatal: the
  // command line and result code are printed to stderr and the process is
  // terminated.
  void exec() {
    auto argsRef = std::vector<llvm::StringRef>(_args.begin(), _args.end());
    bool verbose = false;
    if (const auto &verboseStr = getEnvVar("VERBOSE"))
      istringstream(verboseStr.getValue()) >> verbose;

    // If in verbose mode, print out command before execution.
    if (verbose)
      cout << llvm::join(argsRef, " ") << "\n";

    int rc = llvm::sys::ExecuteAndWait(_path, llvm::makeArrayRef(argsRef));
    if (rc != 0) {
      fprintf(stderr, "%s\n", llvm::join(argsRef, " ").c_str());
      fprintf(stderr, "Command returned %d.\n", rc);
      // A failing external tool is an expected runtime condition, not an
      // unreachable program state: llvm_unreachable would be undefined
      // behaviour in builds without assertions.
      llvm::report_fatal_error(
          "Command execution failed.", /*gen_crash_diag=*/false);
    }
  }
};
} // namespace
void setExecPath(const char *argv0, void *fmain) {
string p;
if (!(p = llvm::sys::fs::getMainExecutable(argv0, fmain)).empty())
kExecPath = p;
}
// Handle '.mlir' input to the ONNX MLIR frontend: read `inputFilename` (or
// stdin, as decided by MemoryBuffer::getFileOrSTDIN) and parse it into
// `module` using `context`. On failure a message is written to llvm::errs();
// if the file cannot be opened `module` is left untouched.
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
    mlir::OwningModuleRef &module) {
  auto bufferOrError = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
  if (!bufferOrError) {
    llvm::errs() << "Could not open input file: "
                 << bufferOrError.getError().message() << "\n";
    return;
  }

  // Hand the buffer over to a source manager and parse the input mlir.
  llvm::SourceMgr sources;
  sources.AddNewSourceBuffer(std::move(bufferOrError.get()), llvm::SMLoc());
  module = mlir::parseSourceFile(sources, &context);
  if (module)
    return;
  llvm::errs() << "Error can't load file " << inputFilename << "\n";
}
// Make the packed constants of `module` available to the compiled model.
//
// The path of the constant pack file is read from a global symbol embedded in
// the module. On macOS and Linux the file is wrapped into an object file whose
// path is stored in `constPackObjPath`, and the original pack file is removed
// when this function returns. On other platforms the pack file is renamed
// next to the output and the module's file-name symbols are updated instead;
// `constPackObjPath` is left untouched in that case.
void genConstPackObj(const mlir::OwningModuleRef &module,
    llvm::Optional<string> &constPackObjPath) {
  // Extract constant pack file name, which is embedded as a symbol in the
  // module being compiled.
  auto constPackFilePathSym = (*module).lookupSymbol<mlir::LLVM::GlobalOp>(
      mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName());
  if (!constPackFilePathSym)
    llvm::report_fatal_error(
        "Constant pack file path symbol not found in module.",
        /*gen_crash_diag=*/false);
  // dyn_cast_or_null yields a null attribute when the symbol's value is
  // missing or is not a string; check before dereferencing.
  auto constPackFilePathAttr =
      constPackFilePathSym.valueAttr().dyn_cast_or_null<mlir::StringAttr>();
  if (!constPackFilePathAttr)
    llvm::report_fatal_error(
        "Constant pack file path symbol does not hold a string.",
        /*gen_crash_diag=*/false);
  auto constPackFilePath = constPackFilePathAttr.getValue().str();
  llvm::FileRemover constPackRemover(constPackFilePath);

#if __APPLE__
  // Create a empty stub file, compile it to an empty obj file.
  llvm::SmallVector<char, 20> stubSrcPath;
  llvm::sys::fs::createTemporaryFile("stub", "cpp", stubSrcPath);
  llvm::FileRemover stubSrcRemover(stubSrcPath);
  std::string stubSrcPathStr(stubSrcPath.begin(), stubSrcPath.end());
  Command createStubObj(/*exePath=*/kCxxPath);
  std::string stubObjPathStr = stubSrcPathStr + ".o";
  createStubObj.appendList({"-o", stubObjPathStr})
      .appendList({"-c", stubSrcPathStr})
      .exec();
  llvm::FileRemover stubObjRemover(stubObjPathStr);

  // Embed data into the empty stub obj file.
  constPackObjPath = constPackFilePath + ".o";
  Command genParamObj(/*exePath=*/kLinkerPath);
  genParamObj.appendStr("-r")
      .appendList({"-o", constPackObjPath.getValue()})
      .appendList({"-sectcreate", "binary", "param", constPackFilePath})
      .appendStr(stubObjPathStr)
      .exec();

#elif __linux__
  // Create param.o holding packed parameter values.
  constPackObjPath = constPackFilePath + ".o";
  Command genParamObj(/*exePath=*/kLinkerPath);
  genParamObj.appendStr("-r")
      .appendList({"-b", "binary"})
      .appendList({"-o", constPackObjPath.getValue()})
      .appendStr(constPackFilePath)
      .exec();

  // Figure out what is the default symbol name describing the start/end
  // address of the embedded data.
  std::regex e("[^0-9A-Za-z]");
  auto sanitizedName =
      "_binary_" + std::regex_replace(constPackFilePath, e, "_");

  // Rename the symbols to saner ones expected by the runtime function.
  Command redefineSym(/*exePath=*/kObjCopyPath);
  redefineSym.appendStr("--redefine-sym")
      .appendStr(sanitizedName + "_start=_binary_param_bin_start")
      .appendStr(constPackObjPath.getValue())
      .exec();
  redefineSym.resetArgs()
      .appendStr("--redefine-sym")
      .appendStr(sanitizedName + "_end=_binary_param_bin_end")
      .appendStr(constPackObjPath.getValue())
      .exec();

#else
  llvm::SmallVector<char, 10> permConstPackFileName(
      constPackFilePath.begin(), constPackFilePath.end());
  llvm::sys::path::replace_extension(permConstPackFileName, "bin");

  std::string permConstPackFileNameStr(
      permConstPackFileName.begin(), permConstPackFileName.end());
  // Materialize the concatenation immediately: a Twine kept in a local
  // variable would refer to temporaries that no longer exist.
  // NOTE(review): outputBaseName is not a parameter of this function; it must
  // be declared elsewhere for this branch to compile - confirm.
  std::string constPackFileName =
      (llvm::sys::path::filename(outputBaseName) + "." +
          llvm::sys::path::filename(permConstPackFileNameStr))
          .str();
  llvm::sys::fs::rename(constPackFilePath, constPackFileName);

  mlir::Builder builder(*module);
  (*module)
      .lookupSymbol<mlir::LLVM::GlobalOp>(
          mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName())
      .valueAttr(builder.getStringAttr(constPackFileName));
  (*module)
      .lookupSymbol<mlir::LLVM::GlobalOp>(
          mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName())
      .valueAttr(builder.getI64IntegerAttr(constPackFileName.size()));
#endif
}
// Write LLVM bitcode: translate `module` to LLVM IR and serialize it to
// `bitcodePath`. Failure to open the output file or to translate the module
// is reported on llvm::errs() and terminates the process.
void genLLVMBitcode(const mlir::OwningModuleRef &module, string bitcodePath) {
  std::error_code error;
  llvm::raw_fd_ostream moduleBitcodeStream(
      bitcodePath, error, llvm::sys::fs::F_None);
  // Report an open failure here, with the path, rather than writing to a
  // stream that was never opened.
  if (error) {
    llvm::errs() << "Could not open " << bitcodePath
                 << " for writing: " << error.message() << "\n";
    llvm::report_fatal_error(
        "Failed to write LLVM bitcode.", /*gen_crash_diag=*/false);
  }

  auto llvmModule = mlir::translateModuleToLLVMIR(*module);
  if (!llvmModule) {
    llvm::errs() << "Could not translate the module to LLVM IR.\n";
    llvm::report_fatal_error(
        "Failed to write LLVM bitcode.", /*gen_crash_diag=*/false);
  }

  llvm::WriteBitcodeToFile(*llvmModule, moduleBitcodeStream);
  moduleBitcodeStream.flush();
}
// Compile LLVM bitcode to object file: runs the tool at kLlcPath on
// `bitcodePath`, requesting a position-independent object written to
// `modelObjPath`. `module` is not used.
void genModelObject(const mlir::OwningModuleRef &module, string bitcodePath,
    string modelObjPath) {
  Command llc(/*exePath=*/kLlcPath);
  llc.appendList({"-filetype=obj", "-relocation-model=pic", "-o",
      modelObjPath, bitcodePath});
  llc.exec();
}
// Run the tool at kArPath in extract mode ("x") on the archive
// `jniSharedLibPath`, naming `jniObjPath` as the member to extract. `module`
// is not used.
void genJniObject(const mlir::OwningModuleRef &module, string jniSharedLibPath,
    string jniObjPath) {
  Command extractMember(/*exePath=*/kArPath);
  extractMember.appendList({"x", jniSharedLibPath, jniObjPath});
  extractMember.exec();
}
// Link everything into a shared object. The tool at kCxxPath is invoked with,
// in this order: `opts`, `objs`, "-o <modelSharedLibPath>", a -L flag that
// points at the runtime directory, and finally `libs`. `module` is not used.
void genSharedLib(const mlir::OwningModuleRef &module,
    string modelSharedLibPath, std::vector<string> opts,
    std::vector<string> objs, std::vector<string> libs) {
  const string librarySearchFlag = "-L" + getRuntimeDir();

  Command linker(kCxxPath);
  linker.appendList(opts);
  linker.appendList(objs);
  linker.appendList({"-o", modelSharedLibPath});
  linker.appendStr(librarySearchFlag);
  linker.appendList(libs);
  linker.exec();
}
// Create jar containing java runtime and model shared library (which includes
// jni runtime). javaruntime.jar from the runtime directory is copied to
// `modelJniJarPath`, then `modelSharedLibPath` is added to that copy with the
// tool at kJarPath. A failed copy is reported and terminates the process.
// `module` is not used.
void genJniJar(const mlir::OwningModuleRef &module, string modelSharedLibPath,
    string modelJniJarPath) {
  llvm::SmallString<8> runtimeDir(getRuntimeDir());
  llvm::sys::path::append(runtimeDir, "javaruntime.jar");
  string javaRuntimeJarPath = llvm::StringRef(runtimeDir).str();

  // Copy javaruntime.jar to model jar. Check the result so that a missing
  // runtime jar is reported here instead of surfacing as a failure of the
  // jar tool below.
  if (std::error_code copyError =
          llvm::sys::fs::copy_file(javaRuntimeJarPath, modelJniJarPath)) {
    llvm::errs() << "Could not copy " << javaRuntimeJarPath << " to "
                 << modelJniJarPath << ": " << copyError.message() << "\n";
    llvm::report_fatal_error(
        "Failed to create JNI jar.", /*gen_crash_diag=*/false);
  }

  // Add shared library to model jar.
  Command jar(kJarPath);
  jar.appendList({"uf", modelJniJarPath}).appendStr(modelSharedLibPath).exec();
}
// Compile `module` into the shared library <outputBaseName>.so.
//
// Intermediate files (constant pack object, <outputBaseName>.bc and
// <outputBaseName>.o) are removed when the function returns.
void compileModuleToSharedLibrary(
    const mlir::OwningModuleRef &module, std::string outputBaseName) {
  llvm::Optional<string> constPackObjPath;
  genConstPackObj(module, constPackObjPath);
  // genConstPackObj does not produce an object file on every platform, so the
  // optional may be empty: only schedule a removal when there is a path.
  llvm::FileRemover constPackObjRemover;
  if (constPackObjPath.hasValue())
    constPackObjRemover.setFile(constPackObjPath.getValue());

  string bitcodePath = outputBaseName + ".bc";
  genLLVMBitcode(module, bitcodePath);
  llvm::FileRemover bitcodeRemover(bitcodePath);

  string modelObjPath = outputBaseName + ".o";
  genModelObject(module, bitcodePath, modelObjPath);
  llvm::FileRemover modelObjRemover(modelObjPath);

  // Pass the constant pack object to the linker only when it exists; an empty
  // string would be handed over as a bogus input file.
  std::vector<string> objs;
  if (constPackObjPath.hasValue())
    objs.emplace_back(constPackObjPath.getValue());
  objs.emplace_back(modelObjPath);

  string modelSharedLibPath = outputBaseName + ".so";
  genSharedLib(module, modelSharedLibPath, {"-shared", "-fPIC"}, objs,
      {"-lEmbeddedDataLoader", "-lcruntime"});
}
// Compile `module` into the JNI archive <outputBaseName>.jar.
//
// The model is compiled to an object file, linked together with the JNI
// object extracted from libjniruntime.a into libmodel.so, and that library is
// added to a copy of the java runtime jar. All intermediate files, including
// jnidummy.c.o and libmodel.so in the current directory, are removed when the
// function returns.
void compileModuleToJniJar(
    const mlir::OwningModuleRef &module, std::string outputBaseName) {
  llvm::Optional<string> constPackObjPath;
  genConstPackObj(module, constPackObjPath);
  // genConstPackObj does not produce an object file on every platform, so the
  // optional may be empty: only schedule a removal when there is a path.
  llvm::FileRemover constPackObjRemover;
  if (constPackObjPath.hasValue())
    constPackObjRemover.setFile(constPackObjPath.getValue());

  string bitcodePath = outputBaseName + ".bc";
  genLLVMBitcode(module, bitcodePath);
  llvm::FileRemover bitcodeRemover(bitcodePath);

  string modelObjPath = outputBaseName + ".o";
  genModelObject(module, bitcodePath, modelObjPath);
  llvm::FileRemover modelObjRemover(modelObjPath);

  string jniSharedLibPath = getRuntimeDir() + "/libjniruntime.a";
  string jniObjPath = "jnidummy.c.o";
  genJniObject(module, jniSharedLibPath, jniObjPath);
  llvm::FileRemover jniObjRemover(jniObjPath);

  // Pass the constant pack object to the linker only when it exists; an empty
  // string would be handed over as a bogus input file.
  std::vector<string> objs;
  if (constPackObjPath.hasValue())
    objs.emplace_back(constPackObjPath.getValue());
  objs.emplace_back(modelObjPath);
  objs.emplace_back(jniObjPath);

  string modelSharedLibPath = "libmodel.so";
  genSharedLib(module, modelSharedLibPath,
      {"-shared", "-fPIC", "-z", "noexecstack"}, objs,
      {"-lEmbeddedDataLoader", "-lcruntime", "-ljniruntime"});
  llvm::FileRemover modelSharedLibRemover(modelSharedLibPath);

  string modelJniJarPath = outputBaseName + ".jar";
  genJniJar(module, modelSharedLibPath, modelJniJarPath);
}
// Register every dialect used by this file's pass pipelines: Affine, LLVM,
// SCF and Standard, followed by the ONNX and Krnl dialects.
void registerDialects() {
mlir::registerDialect<mlir::AffineDialect>();
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::scf::SCFDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>();
}
// Append the ONNX-level passes to `pm`, in this order: decomposition,
// constant propagation, shape inference, canonicalization and attribute
// promotion, followed by a second round of shape inference and attribute
// promotion.
void addONNXToMLIRPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createDecomposeONNXToONNXPass());
pm.addPass(mlir::createConstPropONNXToONNXPass());
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createAttributePromotionPass());
// NOTE(review): shape inference and attribute promotion are scheduled a
// second time, presumably to pick up what the first round exposed - confirm.
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createAttributePromotionPass());
}
// Append to `pm` the passes that lower to Krnl and pack Krnl global
// constants, then the memory-pool passes, with a canonicalization after each
// group.
void addONNXToKrnlPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createLowerToKrnlPass());
pm.addPass(mlir::createPackKrnlGlobalConstantsPass());
// An additional pass of canonicalization is helpful because lowering
// from ONNX dialect to Standard dialect exposes additional canonicalization
// opportunities.
pm.addPass(mlir::createCanonicalizerPass());
// TODO: make this pass optional:
pm.addPass(mlir::createKrnlEnableMemoryPoolPass());
pm.addPass(mlir::createKrnlBundleMemoryPoolsPass());
pm.addPass(mlir::createCanonicalizerPass());
}
// Append the Krnl lowering pass to `pm`. Loop fusion in the Affine dialect is
// currently disabled (see the commented-out pass below).
void addKrnlToAffinePasses(mlir::PassManager &pm) {
pm.addPass(mlir::createLowerKrnlPass());
// Fuse loops in Affine dialect.
// pm.addPass(mlir::createLoopFusionPass());
}
// Append to `pm` the passes that lower Affine, lower to CFG and lower Krnl to
// LLVM, in that order, followed by a final canonicalization.
void addKrnlToLLVMPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createLowerToCFGPass());
pm.addPass(mlir::createKrnlLowerToLLVMPass());
pm.addPass(mlir::createCanonicalizerPass());
}
// Load `inputFilename` into `module`. The text after the last '.' of the name
// selects the loader: "onnx" goes through ImportFrontendModelFile, anything
// else through LoadMLIR (an assertion checks that it is "mlir").
// `emissionTarget` is not used.
void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
    mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
  // When the name has no '.', rfind yields npos and npos + 1 wraps to 0, so
  // the whole name is taken as the suffix.
  const string suffix = inputFilename.substr(inputFilename.rfind('.') + 1);
  const bool isOnnxModel = suffix == "onnx";
  const bool isMlirSource = suffix == "mlir";
  assert(isOnnxModel != isMlirSource &&
         "Either ONNX model or MLIR file needs to be provided.");

  if (!isOnnxModel) {
    LoadMLIR(inputFilename, context, module);
    return;
  }
  ImportFrontendModelFile(inputFilename, context, module);
}
// Dump `module` into the file named `filename` + `extension`.
//
// module->dump() writes to stderr, so stderr is temporarily pointed at the
// target file. On Windows this is done in-process and the original stderr is
// restored afterwards. Elsewhere a child process performs the redirection, so
// the parent's stderr is not changed; the parent waits for the child and so
// the file is complete when this function returns. Problems are reported on
// llvm::errs().
void outputCode(
    mlir::OwningModuleRef &module, string filename, string extension) {
  string tempFilename = filename + extension;

#ifdef _WIN32
  // copy original stderr file number
  const int stderrOrigin = _dup(_fileno(stderr));
  if (stderrOrigin == -1) {
    // Without a saved descriptor stderr could not be restored afterwards.
    llvm::errs() << "Could not duplicate stderr; " << tempFilename
                 << " was not written.\n";
    return;
  }
  // NOTE(review): the freopen result is still unchecked here - confirm how a
  // failure should be handled on Windows.
  freopen(tempFilename.c_str(), "w", stderr);
  module->dump();
  fflush(stderr);
  // set modified stderr as original stderr
  _dup2(stderrOrigin, _fileno(stderr));
  // The duplicate is no longer needed once stderr has been restored.
  _close(stderrOrigin);
#else
  // Flush pending stdio output so the child does not inherit buffered data.
  fflush(stdout);
  fflush(stderr);

  const pid_t child = fork();
  if (child < 0) {
    llvm::errs() << "Could not fork a process to write " << tempFilename
                 << "\n";
    return;
  }

  if (child == 0) {
    // Child: redirect stderr into the file, dump and leave. _exit is used so
    // the child neither runs the parent's atexit handlers nor flushes its
    // stdio buffers.
    if (!freopen(tempFilename.c_str(), "w", stderr))
      _exit(1);
    module->dump();
    fclose(stderr);
    _exit(0);
  }

  // Parent: reap the child so no zombie is left behind.
  int status = 0;
  pid_t waited;
  do {
    waited = waitpid(child, &status, 0);
  } while (waited < 0 && errno == EINTR);
  if (waited < 0 || !WIFEXITED(status) || WEXITSTATUS(status) != 0)
    llvm::errs() << "Could not write " << tempFilename << "\n";
#endif
}
// Produce the output artifacts for `emissionTarget`.
//
// - EmitLib: compile and link the shared library <outputBaseName>.so.
// - EmitJNI: build the JNI archive <outputBaseName>.jar.
// - Any other target: write the module, with all constant values included,
//   to <outputBaseName>.onnx.mlir; this is the version meant to be passed
//   back to onnx-mlir for further processing. For EmitONNXBasic, EmitONNXIR
//   and EmitMLIR a second version with the constants elided, meant for being
//   inspected by users, is written to <outputBaseName>.tmp. In the LLVM
//   dialect the constants are grouped outside the function code, so no elided
//   version is produced for that target.
void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
    mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
  if (emissionTarget == EmitLib) {
    // Write LLVM bitcode to disk, compile & link.
    compileModuleToSharedLibrary(module, outputBaseName);
    printf("Shared library %s.so has been compiled.\n", outputBaseName.c_str());
    return;
  }

  if (emissionTarget == EmitJNI) {
    compileModuleToJniJar(module, outputBaseName);
    printf("JNI archive %s.jar has been compiled.\n", outputBaseName.c_str());
    return;
  }

  // Emit the version with all constants included.
  const string fullCodePath = outputBaseName + ".onnx.mlir";
  outputCode(module, outputBaseName, ".onnx.mlir");
  printf("Full MLIR code written to: \n\t%s\n\n", fullCodePath.c_str());

  // Select the clean-up pass that matches the emission level.
  const bool isOnnxLevel =
      emissionTarget == EmitONNXIR || emissionTarget == EmitONNXBasic;
  const bool isMlirLevel = emissionTarget == EmitMLIR;

  mlir::PassManager cleanSourcePM(&context);
  if (isOnnxLevel)
    cleanSourcePM.addPass(mlir::createElideConstantValuePass());
  if (isMlirLevel)
    cleanSourcePM.addPass(mlir::createElideConstGlobalValuePass());

  // The other targets do not get a constant-free copy.
  if (!isOnnxLevel && !isMlirLevel)
    return;

  // A failed clean-up is reported, but the file is still written.
  if (mlir::failed(cleanSourcePM.run(*module)))
    llvm::errs() << "Could not apply simplification passes.\n";

  const string elidedCodePath = outputBaseName + ".tmp";
  outputCode(module, outputBaseName, ".tmp");
  printf("Constant-free MLIR Code written to: \n\t%s\n\n",
      elidedCodePath.c_str());
  printf("Use:\n\t%s\nto continue lowering the code to other dialects.\n",
      fullCodePath.c_str());
}
// Build the pass pipeline required by `emissionTarget`, run it over `module`
// and emit the output files. Returns 4 when the pipeline fails (nothing is
// emitted in that case) and 0 otherwise.
int compileModule(mlir::OwningModuleRef &module, mlir::MLIRContext &context,
    std::string outputBaseName, EmissionTargetType emissionTarget) {
  mlir::PassManager pipeline(&context);

  // Each emission level includes the lowering stages of the levels below it.
  const bool reachesOnnxIR = emissionTarget >= EmitONNXIR;
  const bool reachesMlir = emissionTarget >= EmitMLIR;
  const bool reachesLlvmIR = emissionTarget >= EmitLLVMIR;

  if (reachesOnnxIR)
    addONNXToMLIRPasses(pipeline);
  if (reachesMlir) {
    addONNXToKrnlPasses(pipeline);
    addKrnlToAffinePasses(pipeline);
  }
  if (reachesLlvmIR)
    addKrnlToLLVMPasses(pipeline);

  const bool pipelineFailed = mlir::failed(pipeline.run(*module));
  if (pipelineFailed)
    return 4;

  emitOutputFiles(outputBaseName, emissionTarget, context, module);
  return 0;
}