100 lines
3.5 KiB
C++
100 lines
3.5 KiB
C++
|
//===--------------------------- main_utils.cpp ---------------------------===//
|
||
|
//
|
||
|
// Copyright 2019-2020 The IBM Research Authors.
|
||
|
//
|
||
|
// =============================================================================
|
||
|
//
|
||
|
// Functions for adding passes and processing input files.
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "src/MainUtils.hpp"
|
||
|
|
||
|
using namespace std;
|
||
|
using namespace onnx_mlir;
|
||
|
|
||
|
/// Loads a textual MLIR file into `module`.
///
/// The '.mlir' format means the file contains one or more of the dialect
/// representations registered with `context`. The file is read through
/// llvm::MemoryBuffer::getFileOrSTDIN (per LLVM docs, the name "-" selects
/// stdin).
///
/// On a read failure, a diagnostic is printed and `module` is left untouched.
/// On a parse failure, `module` receives the (null) parse result and a
/// diagnostic is printed.
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
    mlir::OwningModuleRef &module) {
  auto bufferOrErr = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
  if (!bufferOrErr) {
    llvm::errs() << "Could not open input file: "
                 << bufferOrErr.getError().message() << "\n";
    return;
  }

  // Hand the buffer to a SourceMgr so the parser can attach source locations.
  llvm::SourceMgr sources;
  sources.AddNewSourceBuffer(std::move(bufferOrErr.get()), llvm::SMLoc());

  module = mlir::parseSourceFile(sources, &context);
  if (!module)
    llvm::errs() << "Error can't load file " << inputFilename << "\n";
}
|
||
|
|
||
|
/// Translates `module` to LLVM IR and writes it as bitcode to "model.bc" in
/// the current working directory.
///
/// Failures are reported on llvm::errs() and the function returns early:
///  - If the output file cannot be opened, nothing is written. Writing to a
///    stream whose open failed would put it in an error state that
///    raw_fd_ostream treats as fatal on destruction.
///  - If translation to LLVM IR fails (a null module is returned), nothing is
///    written. The original code dereferenced that null pointer.
void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
  error_code error;
  llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error,
      llvm::sys::fs::F_None);
  if (error) {
    llvm::errs() << "Could not open output file model.bc: " << error.message()
                 << "\n";
    return;
  }

  auto llvmModule = mlir::translateModuleToLLVMIR(*module);
  if (!llvmModule) {
    llvm::errs() << "Failed to translate module to LLVM IR\n";
    return;
  }

  llvm::WriteBitcodeToFile(*llvmModule, moduleBitcodeStream);
  moduleBitcodeStream.flush();
}
|
||
|
|
||
|
void registerDialects() {
|
||
|
mlir::registerDialect<mlir::AffineOpsDialect>();
|
||
|
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
|
||
|
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
|
||
|
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||
|
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||
|
}
|
||
|
|
||
|
/// Appends the ONNX-level pipeline to `pm`:
/// decompose -> shape inference -> canonicalize -> shape inference ->
/// attribute promotion.
void addONNXToMLIRPasses(mlir::PassManager &pm) {
  // Each created pass is handed straight to the pass manager, in sequence.
  auto append = [&pm](auto pass) { pm.addPass(std::move(pass)); };

  append(mlir::createDecomposeONNXToONNXPass());
  append(mlir::createShapeInferencePass());
  append(mlir::createCanonicalizerPass());
  // Shape inference runs a second time, after canonicalization.
  append(mlir::createShapeInferencePass());
  append(mlir::createAttributePromotionPass());
}
|
||
|
|
||
|
/// Appends the lowering from the ONNX dialect to Krnl, followed by a
/// canonicalization pass.
void addONNXToKrnlPasses(mlir::PassManager &pm) {
  auto append = [&pm](auto pass) { pm.addPass(std::move(pass)); };

  append(mlir::createLowerToKrnlPass());
  // Canonicalizing again pays off: lowering out of the ONNX dialect exposes
  // additional canonicalization opportunities.
  append(mlir::createCanonicalizerPass());
}
|
||
|
|
||
|
/// Appends the single pass that lowers Krnl constructs (createLowerKrnlPass).
void addKrnlToAffinePasses(mlir::PassManager &pm) {
  auto lowerKrnl = mlir::createLowerKrnlPass();
  pm.addPass(std::move(lowerKrnl));
}
|
||
|
|
||
|
/// Appends the final lowering to the LLVM dialect:
/// affine lowering -> CFG lowering -> Krnl-to-LLVM -> canonicalize.
void addKrnlToLLVMPasses(mlir::PassManager &pm) {
  auto append = [&pm](auto pass) { pm.addPass(std::move(pass)); };

  append(mlir::createLowerAffinePass());
  append(mlir::createLowerToCFGPass());
  append(mlir::createKrnlLowerToLLVMPass());
  append(mlir::createCanonicalizerPass());
}
|
||
|
|
||
|
void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
|
||
|
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||
|
// Decide if the input file is an ONNX model or a model specified
|
||
|
// in MLIR. The extension of the file is the decider.
|
||
|
string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1);
|
||
|
bool inputIsONNX = (extension == "onnx");
|
||
|
bool inputIsMLIR = (extension == "mlir");
|
||
|
assert(inputIsONNX != inputIsMLIR &&
|
||
|
"Either ONNX model or MLIR file needs to be provided.");
|
||
|
|
||
|
|
||
|
if (inputIsONNX) {
|
||
|
ImportFrontendModelFile(inputFilename, context, module);
|
||
|
} else {
|
||
|
LoadMLIR(inputFilename, context, module);
|
||
|
}
|
||
|
}
|