// Source-viewer metadata (not part of the original file):
// mlir-hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp — 286 lines, 8.8 KiB, C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "iostream"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/DebugCounter.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/Timing.h"
#include "mlir/Support/ToolUtilities.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/DynamicLibrary.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/JitRunner.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
// #include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
using namespace mlir;
using namespace llvm;
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();
mlir::initializeLLVMPasses();
// Register any command line options.
// registerAsmPrinterCLOptions();
// registerMLIRContextCLOptions();
// registerPassManagerCLOptions();
// registerDefaultTimingManagerCLOptions();
DebugCounter::registerCLOptions();
mlir::registerAllPasses();
mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses();
mlir::disc_ral::registerAllDiscRalPasses();
mlir::DialectRegistry registry;
mlir::registerAllToLLVMIRTranslations(registry);
mlir::registerAllDialects(registry);
registry.insert<mlir::mhlo::MhloDialect>();
registry.insert<mlir::chlo::HloClientDialect>();
registry.insert<mlir::lmhlo::LmhloDialect>();
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
registry.insert<mlir::disc_ral::RalDialect>();
// failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
// registry,
// /*preloadDialectsInContext=*/false));
// return 0;
std::string errorMessage;
// auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir", &errorMessage);
auto file = mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage);
std::cout<<"errorMessage:" <<errorMessage <<std::endl;
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
SmallVector<const llvm::PassInfo *, 4> passes;
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
auto tmOrError = tmBuilderOrError->createTargetMachine();
auto transformer = mlir::makeLLVMPassesTransformer(
passes, 0, /*targetMachine=*/tmOrError->get(), 0);
MLIRContext context(registry);
context.loadAllAvailableDialects();
OwningModuleRef module(parseSourceFile(sourceMgr, &context));
module->dump();
auto errorHandler = [&](const Twine &msg) {
// emitError(UnknownLoc::get(context)) << msg;
return failure();
};
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
// RUN: -hlo-legalize-to-lhlo -buffer-hoisting \
// RUN: -buffer-deallocation -canonicalize -cse \
// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
// RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \
// RUN: -convert-std-to-llvm
mlir::PassManager pm(&context, OpPassManager::Nesting::Implicit);
// pm.enableVerifier(verifyPasses);
applyPassManagerCLOptions(pm);
// pm.enableTiming(timing);
pm.addPass(mlir::mhlo::createChloLegalizeToHloPass());
pm.addPass(mlir::mhlo::createLegalizeToLhloPass());
pm.addPass(mlir::createBufferHoistingPass());
pm.addPass(mlir::createBufferDeallocationPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::lmhlo::createLegalizeLhloToLinalgPass());
pm.addPass(mlir::lmhlo::createLhloFuseLinalgPass());
pm.addPass(mlir::createConvertLinalgToLoopsPass());
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createLowerToCFGPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createLowerToLLVMPass());
pm.run(*module);
module->dump();
std::cout<<"DEBUG load module success"<<std::endl;
llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default;
// If shared library implements custom mlir-runner library init and destroy
// functions, we'll use them to register the library with the execution
// engine. Otherwise we'll pass library directly to the execution engine.
SmallVector<SmallString<256>, 4> libPaths;
// Use absolute library path so that gdb can find the symbol table.
std::list<std::string> sharedlib;
sharedlib.push_back("/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git");
transform(
sharedlib,
std::back_inserter(libPaths),
[](std::string libPath) {
SmallString<256> absPath(libPath.begin(), libPath.end());
cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
return absPath;
});
// Libraries that we'll pass to the ExecutionEngine for loading.
SmallVector<StringRef, 4> executionEngineLibs;
using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
using MlirRunnerDestroyFn = void (*)();
llvm::StringMap<void *> exportSymbols;
SmallVector<MlirRunnerDestroyFn> destroyFns;
// Handle libraries that do support mlir-runner init/destroy callbacks.
for (auto &libPath : libPaths) {
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
// Library does not support mlir runner, load it with ExecutionEngine.
if (!initSym || !destroySim) {
executionEngineLibs.push_back(libPath);
continue;
}
auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
initFn(exportSymbols);
auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
destroyFns.push_back(destroyFn);
}
// Build a runtime symbol map from the config and exported symbols.
auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
auto symbolMap = llvm::orc::SymbolMap();
for (auto &exportSymbol : exportSymbols)
symbolMap[interner(exportSymbol.getKey())] =
llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
return symbolMap;
};
auto expectedEngine = mlir::ExecutionEngine::create(
module.get(), nullptr, transformer, jitCodeGenOptLevel,
executionEngineLibs);
// if (!expectedEngine)
// return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
engine->registerSymbols(runtimeSymbolMap);
auto expectedFPtr = engine->lookup("main");
if (!expectedFPtr)
// std::cout<<"expectedFPtr "<<expectedFPtr.takeError()<<std::endl;
return 1;
// if (options.dumpObjectFile)
// engine->dumpToObjectFile("a.o");
int res;
struct {
void *data;
} data;
data.data = &res;
void (*fptr)(void **) = *expectedFPtr;
(*fptr)((void **)&data);
std::cout<<"result.data "<<res<<std::endl;
// Run all dynamic library destroy callbacks to prepare for the shutdown.
llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
return 0;
}