onnx-mlir/src/MainUtils.cpp

167 lines
6.2 KiB
C++

//===---------------------------- MainUtils.cpp ---------------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// Functions for adding passes and processing input files.
//
//===----------------------------------------------------------------------===//
#include "src/MainUtils.hpp"
#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>
using namespace std;
using namespace onnx_mlir;
// Loads the '.mlir' input named by `inputFilename` and parses it into
// `module` using `context`.
//
// The file contents are obtained through llvm::MemoryBuffer::getFileOrSTDIN.
// When the buffer cannot be obtained, or when parsing yields an empty module,
// a diagnostic is written to llvm::errs() and the function simply returns;
// no error is propagated to the caller other than through `module`.
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
    mlir::OwningModuleRef &module) {
  // Obtain the raw contents of the input.
  auto bufferOrError = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
  const std::error_code openError = bufferOrError.getError();
  if (openError) {
    llvm::errs() << "Could not open input file: " << openError.message()
                 << "\n";
    return;
  }

  // Hand the buffer to a source manager and parse the MLIR it contains.
  llvm::SourceMgr sources;
  sources.AddNewSourceBuffer(std::move(*bufferOrError), llvm::SMLoc());
  module = mlir::parseSourceFile(sources, &context);
  if (module)
    return;

  llvm::errs() << "Error can't load file " << inputFilename << "\n";
}
// Translates `module` to LLVM IR and writes it as bitcode to `outputFilename`.
//
// Failures are reported on llvm::errs() and the function returns without
// writing anything:
//   * the MLIR-to-LLVM-IR translation produced no module, or
//   * the output file could not be opened.
void EmitLLVMBitCode(
    const mlir::OwningModuleRef &module, string outputFilename) {
  // Translate before touching the file system so that a failed translation
  // does not leave an empty output file behind.
  auto llvmModule = mlir::translateModuleToLLVMIR(*module);
  if (!llvmModule) {
    llvm::errs() << "Could not translate module to LLVM IR; "
                 << outputFilename << " was not written.\n";
    return;
  }

  std::error_code error;
  llvm::raw_fd_ostream moduleBitcodeStream(
      outputFilename, error, llvm::sys::fs::F_None);
  // The stream constructor reports open failures through `error`; writing to
  // a stream that failed to open would silently produce no output.
  if (error) {
    llvm::errs() << "Could not open output file " << outputFilename << ": "
                 << error.message() << "\n";
    return;
  }

  llvm::WriteBitcodeToFile(*llvmModule, moduleBitcodeStream);
  moduleBitcodeStream.flush();
}
// Registers every dialect used by this tool through mlir::registerDialect:
// Affine, LLVM, Loop, Standard, ONNX and Krnl, in that order.
void registerDialects() {
  // Dialects provided by MLIR itself.
  mlir::registerDialect<mlir::AffineDialect>();
  mlir::registerDialect<mlir::LLVM::LLVMDialect>();
  mlir::registerDialect<mlir::loop::LoopOpsDialect>();
  mlir::registerDialect<mlir::StandardOpsDialect>();

  // ONNX and Krnl dialects.
  mlir::registerDialect<mlir::ONNXOpsDialect>();
  mlir::registerDialect<mlir::KrnlOpsDialect>();
}
// Appends the ONNX-level passes to `pm`, in execution order:
// decompose ONNX-to-ONNX, shape inference, canonicalization, shape inference
// again, and attribute promotion.
void addONNXToMLIRPasses(mlir::PassManager &pm) {
  pm.addPass(mlir::createDecomposeONNXToONNXPass());

  // Shape inference is scheduled both before and after canonicalization.
  pm.addPass(mlir::createShapeInferencePass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createShapeInferencePass());

  pm.addPass(mlir::createAttributePromotionPass());
}
// Appends to `pm` the pass that lowers to Krnl, followed by a
// canonicalization pass.
void addONNXToKrnlPasses(mlir::PassManager &pm) {
  pm.addPass(mlir::createLowerToKrnlPass());

  // Canonicalize once more: the lowering above exposes additional
  // canonicalization opportunities.
  pm.addPass(mlir::createCanonicalizerPass());
}
// Appends to `pm` the single pass created by mlir::createLowerKrnlPass().
void addKrnlToAffinePasses(mlir::PassManager &pm) {
  pm.addPass(mlir::createLowerKrnlPass());
}
// Appends to `pm` the final lowering passes, in execution order:
// lower Affine, lower to CFG, lower Krnl to LLVM, then canonicalize.
void addKrnlToLLVMPasses(mlir::PassManager &pm) {
  // Step-wise lowering.
  pm.addPass(mlir::createLowerAffinePass());
  pm.addPass(mlir::createLowerToCFGPass());
  pm.addPass(mlir::createKrnlLowerToLLVMPass());

  // Clean up the lowered code.
  pm.addPass(mlir::createCanonicalizerPass());
}
// Loads the model stored in `inputFilename` into `module`.
//
// The file extension decides how the input is read:
//   * "onnx" -> ImportFrontendModelFile
//   * "mlir" -> LoadMLIR
// Any other extension is rejected: an error is written to llvm::errs() and
// the function returns without assigning `module`.
//
// `emissionTarget` is accepted for interface compatibility; it is not used
// when reading the input.
void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
    mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
  // Everything after the last '.' is the extension. When the name has no '.',
  // find_last_of returns npos and npos + 1 wraps to 0, so the whole name is
  // compared against the supported extensions.
  string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
  bool inputIsONNX = (extension == "onnx");
  bool inputIsMLIR = (extension == "mlir");
  assert(inputIsONNX != inputIsMLIR &&
         "Either ONNX model or MLIR file needs to be provided.");

  // The assert above is compiled out when NDEBUG is defined; without this
  // runtime check an unsupported extension would silently be parsed as MLIR.
  if (!inputIsONNX && !inputIsMLIR) {
    llvm::errs() << "Either ONNX model or MLIR file needs to be provided.\n";
    return;
  }

  if (inputIsONNX) {
    ImportFrontendModelFile(inputFilename, context, module);
  } else {
    LoadMLIR(inputFilename, context, module);
  }
}
// Dumps `module` into the file named `filename + extension`.
//
// module->dump() is redirected by reopening stderr onto the output file. To
// keep that redirection from affecting the caller, the work is done in a
// forked child process; I/O redirection changes made there are not visible to
// the parent process. The parent returns immediately after the fork and does
// not wait for the child.
//
// If the child process cannot be created, an error is written to llvm::errs()
// and nothing is emitted. If the output file cannot be opened, the child exits
// with a non-zero status without dumping.
void outputCode(
    mlir::OwningModuleRef &module, string filename, string extension) {
  const pid_t child = fork();
  if (child < 0) {
    llvm::errs() << "Could not start a process to write " << filename
                 << extension << "\n";
    return;
  }
  if (child != 0)
    return; // Parent: nothing more to do.

  // Child. The path must be held in a named string: taking c_str() of the
  // temporary `filename + extension` would leave a dangling pointer once the
  // statement ends.
  const string outputPath = filename + extension;
  if (freopen(outputPath.c_str(), "w", stderr) == nullptr)
    exit(1); // stderr is unusable after a failed freopen; do not dump.

  module->dump();
  fclose(stderr);
  exit(0);
}
// Writes the output file(s) for `emissionTarget`, naming them after
// `outputBaseName`.
//
// EmitLLVMBC produces a single file:
//
//   <name>.bc
//
// Every other target first produces the code with all constant values
// included, intended to be passed back to onnx-mlir for further processing:
//
//   <name>.onnx.mlir
//
// For EmitONNXBasic, EmitONNXIR and EmitMLIR the embedded constant values make
// the code hard to read, so a second version is produced with the constants
// elided, intended for inspection by users:
//
//   <name>.mlir
//
// For the LLVM Dialect IR the constant values are grouped outside the function
// code at the beginning of the file, so eliding them is not strictly required.
// Elision is also unnecessary for the .bc file.
void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
    mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
  // Bitcode target: write the .bc file and finish.
  if (emissionTarget == EmitLLVMBC) {
    string bitcodePath = outputBaseName + ".bc";
    EmitLLVMBitCode(module, bitcodePath);
    printf("LLVM bitcode written to %s\n", bitcodePath.c_str());
    return;
  }

  // All remaining targets: start with the version that keeps every constant.
  outputCode(module, outputBaseName, ".onnx.mlir");
  printf("Full MLIR code written to: \n\t%s\n\n",
      (outputBaseName + ".onnx.mlir").c_str());

  // Select the clean-up pass matching the target, if any.
  const bool onnxLevelTarget =
      emissionTarget == EmitONNXIR || emissionTarget == EmitONNXBasic;
  const bool mlirLevelTarget = emissionTarget == EmitMLIR;

  mlir::PassManager cleanSourcePM(&context);
  if (onnxLevelTarget)
    cleanSourcePM.addPass(mlir::createElideConstantValuePass());
  if (mlirLevelTarget)
    cleanSourcePM.addPass(mlir::createElideConstGlobalValuePass());

  // Targets without a clean-up pass get no constant-free version.
  if (!onnxLevelTarget && !mlirLevelTarget)
    return;

  // A failing clean-up is reported, but the module is still written out.
  if (mlir::failed(cleanSourcePM.run(*module)))
    llvm::errs() << "Could not apply simplification passes.\n";

  outputCode(module, outputBaseName, ".mlir");
  printf("Constant-free MLIR Code written to: \n\t%s\n\n",
      (outputBaseName + ".mlir").c_str());
  printf("Use:\n\t%s\nto continue lowering the code to other dialects.\n",
      (outputBaseName + ".onnx.mlir").c_str());
}