2019-11-12 10:31:56 +08:00
|
|
|
//===--------------------------- main.cpp ---------------------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The IBM Research Authors.
|
|
|
|
//
|
|
|
|
// =============================================================================
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-12-20 03:24:37 +08:00
|
|
|
#include <cmath>
|
|
|
|
#include <cstdlib>
|
|
|
|
#include <iostream>
|
|
|
|
#include <random>
|
|
|
|
#include <tuple>
|
|
|
|
|
|
|
|
#include <boost/date_time/posix_time/posix_time_types.hpp>
|
|
|
|
#include <boost/log/attributes/named_scope.hpp>
|
|
|
|
#include <boost/log/core.hpp>
|
|
|
|
#include <boost/log/expressions.hpp>
|
|
|
|
#include <boost/log/sinks/sync_frontend.hpp>
|
|
|
|
#include <boost/log/sinks/text_file_backend.hpp>
|
|
|
|
#include <boost/log/sinks/text_ostream_backend.hpp>
|
|
|
|
#include <boost/log/sources/logger.hpp>
|
|
|
|
#include <boost/log/support/date_time.hpp>
|
|
|
|
#include <boost/log/trivial.hpp>
|
|
|
|
#include <boost/log/utility/setup/common_attributes.hpp>
|
|
|
|
#include <boost/log/utility/setup/console.hpp>
|
|
|
|
#include <boost/log/utility/setup/file.hpp>
|
|
|
|
|
|
|
|
#include <boost/program_options.hpp>
|
|
|
|
|
2019-11-28 11:56:34 +08:00
|
|
|
#include "llvm/Bitcode/BitcodeWriter.h"
|
2019-11-16 02:10:41 +08:00
|
|
|
#include "llvm/Support/FileUtilities.h"
|
|
|
|
#include "llvm/Support/Regex.h"
|
|
|
|
#include "llvm/Support/SourceMgr.h"
|
|
|
|
|
2019-11-02 05:09:48 +08:00
|
|
|
#include "src/builder/frontend_dialect_transformer.hpp"
|
2019-11-12 10:31:56 +08:00
|
|
|
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
2019-11-02 05:09:48 +08:00
|
|
|
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
2019-11-08 00:42:40 +08:00
|
|
|
#include "src/compiler/pass/passes.hpp"
|
2019-10-09 07:25:59 +08:00
|
|
|
|
2019-11-08 00:42:40 +08:00
|
|
|
#include "mlir/Analysis/Verifier.h"
|
2019-11-28 11:56:34 +08:00
|
|
|
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
|
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
2019-11-08 00:42:40 +08:00
|
|
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
|
|
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
|
|
|
#include "mlir/IR/MLIRContext.h"
|
2019-10-09 07:25:59 +08:00
|
|
|
#include "mlir/IR/Module.h"
|
2019-11-08 00:42:40 +08:00
|
|
|
#include "mlir/Parser.h"
|
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Pass/PassManager.h"
|
|
|
|
#include "mlir/Target/LLVMIR.h"
|
|
|
|
#include "mlir/Transforms/Passes.h"
|
2019-10-09 07:25:59 +08:00
|
|
|
|
2019-12-20 03:24:37 +08:00
|
|
|
using namespace std;
|
2019-10-30 01:57:56 +08:00
|
|
|
using namespace onnf;
|
2019-12-20 03:24:37 +08:00
|
|
|
|
2019-11-16 02:10:41 +08:00
|
|
|
/// Handle '.mlir' input to the ONNF frontend: parse a textual MLIR file
/// into `module`.
///
/// The mlir format indicates that one or more of the supported
/// representations are used in the file. The file is read through
/// llvm::MemoryBuffer::getFileOrSTDIN.
///
/// @param inputFilename path of the file to read.
/// @param context       MLIR context the parsed module is created in.
/// @param module        receives the parsed module.
///
/// Failure handling (nothing is thrown, nothing is returned):
///  - the file cannot be opened: a message goes to llvm::errs() and
///    `module` is left untouched;
///  - the file does not parse: `module` is assigned the parser's null
///    result and a message goes to llvm::errs().
void LoadMLIR(string inputFilename, mlir::MLIRContext& context,
              mlir::OwningModuleRef& module) {
  auto buffer = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
  if (auto err = buffer.getError()) {
    llvm::errs() << "Could not open input file: " << err.message() << "\n";
    return;
  }

  // The source manager takes ownership of the buffer and is what the MLIR
  // parser uses to produce located diagnostics.
  llvm::SourceMgr mgr;
  mgr.AddNewSourceBuffer(std::move(buffer.get()), llvm::SMLoc());

  module = mlir::parseSourceFile(mgr, &context);
  if (module)
    return;

  llvm::errs() << "Error can't load file " << inputFilename << "\n";
}
|
|
|
|
|
2019-12-20 02:27:15 +08:00
|
|
|
int main(int ac, char *av[]) {
|
2019-12-20 03:24:37 +08:00
|
|
|
namespace po = boost::program_options;
|
|
|
|
|
|
|
|
po::options_description desc("ONNF available options");
|
|
|
|
// clang-format off
|
|
|
|
desc.add_options()("help", "produce help message")(
|
|
|
|
"onnx-model", po::value<string>()->required(),
|
|
|
|
"onnx model file");
|
|
|
|
// clang-format on
|
|
|
|
|
2019-11-16 02:10:41 +08:00
|
|
|
// Handle command line argument with option names and positional
|
|
|
|
// command line arguments.
|
|
|
|
po::positional_options_description p;
|
|
|
|
p.add("onnx-model", -1);
|
2019-12-20 03:24:37 +08:00
|
|
|
po::variables_map vm;
|
2019-12-20 02:27:15 +08:00
|
|
|
po::store(po::command_line_parser(ac, av).options(desc).positional(p).run(),
|
|
|
|
vm);
|
2019-11-16 02:10:41 +08:00
|
|
|
|
|
|
|
// TODO: allow multiple input files
|
|
|
|
assert(vm.count("onnx-model") < 2 && "At most one input file can be provided!");
|
2019-12-20 03:24:37 +08:00
|
|
|
|
|
|
|
if (vm.count("help")) {
|
|
|
|
cout << desc << endl;
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
2019-11-02 05:09:48 +08:00
|
|
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
2019-11-12 10:31:56 +08:00
|
|
|
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
2019-11-02 05:09:48 +08:00
|
|
|
|
2019-11-08 00:42:40 +08:00
|
|
|
mlir::MLIRContext context;
|
|
|
|
mlir::OwningModuleRef module;
|
|
|
|
|
2019-11-16 02:10:41 +08:00
|
|
|
// Decide if the input file is an ONNX model or a model specified
|
|
|
|
// in MLIR. The extension of the file is the decider.
|
2019-10-09 07:25:59 +08:00
|
|
|
string model_filename = vm["onnx-model"].as<string>();
|
2019-11-16 02:10:41 +08:00
|
|
|
string extension =
|
|
|
|
model_filename.substr(model_filename.find_last_of(".") + 1);
|
|
|
|
bool onnx_model_provided = (extension == "onnx");
|
|
|
|
bool mlir_model_provided = (extension == "mlir");
|
|
|
|
|
|
|
|
if (onnx_model_provided) {
|
|
|
|
ImportFrontendModelFile(model_filename, context, module);
|
|
|
|
} else if (mlir_model_provided) {
|
|
|
|
LoadMLIR(model_filename, context, module);
|
|
|
|
} else {
|
|
|
|
assert(false && "No ONNX or MLIR models provided!");
|
|
|
|
}
|
2019-11-08 00:42:40 +08:00
|
|
|
|
|
|
|
mlir::PassManager pm(&context);
|
|
|
|
pm.addPass(mlir::createShapeInferencePass());
|
2019-11-13 02:37:46 +08:00
|
|
|
pm.addPass(mlir::createCanonicalizerPass());
|
2019-11-27 02:55:44 +08:00
|
|
|
pm.addPass(mlir::createLowerToKrnlPass());
|
2019-11-28 11:56:34 +08:00
|
|
|
pm.addPass(mlir::createLowerKrnlPass());
|
|
|
|
pm.addPass(mlir::createLowerAffinePass());
|
|
|
|
pm.addPass(mlir::createLowerToCFGPass());
|
2019-12-14 04:28:56 +08:00
|
|
|
pm.addPass(mlir::createKrnlLowerToLLVMPass());
|
2019-11-28 11:56:34 +08:00
|
|
|
pm.addPass(mlir::createCanonicalizerPass());
|
2019-11-08 00:42:40 +08:00
|
|
|
pm.run(*module);
|
2019-10-09 07:25:59 +08:00
|
|
|
|
2019-11-28 11:56:34 +08:00
|
|
|
// Write LLVM bitcode to disk.
|
|
|
|
std::error_code EC;
|
2019-12-20 02:27:15 +08:00
|
|
|
llvm::raw_fd_ostream moduleBitcodeStream("model.bc", EC,
|
|
|
|
llvm::sys::fs::F_None);
|
|
|
|
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
|
|
|
|
moduleBitcodeStream);
|
2019-11-28 11:56:34 +08:00
|
|
|
moduleBitcodeStream.flush();
|
|
|
|
|
2019-12-20 03:24:37 +08:00
|
|
|
return 0;
|
|
|
|
}
|