// mlir-hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
//
// 325 lines
// 9.9 KiB
// C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "iostream"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/DebugCounter.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/Timing.h"
#include "mlir/Support/ToolUtilities.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/DynamicLibrary.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/JitRunner.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
// #include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
using namespace mlir;
using namespace llvm;
namespace utils {

/// Plain aggregate describing a rank-`N` strided buffer of `T`.
///
/// The member order (two data pointers, an offset, then per-dimension sizes
/// and strides) is what callers rely on when brace-initialising the struct and
/// when handing it to JIT-compiled code, so it must not be reordered.
/// NOTE(review): the layout appears intended to match MLIR's ranked memref
/// descriptor as produced by the std-to-LLVM lowering -- confirm against the
/// MLIR version in use.
template <class T, int N>
struct MemRefDescriptor {
  /// Pointer to the start of the underlying allocation.
  T *allocated;
  /// Pointer to the (possibly alignment-adjusted) start of the data.
  T *aligned;
  /// Element offset applied on top of `aligned`.
  int64_t offset;
  /// Extent of each of the `N` dimensions.
  int64_t sizes[N];
  /// Stride, in elements, of each of the `N` dimensions.
  int64_t strides[N];
};

}  // namespace utils
/// Standalone driver: parses a hard-coded MLIR test file, lowers it through
/// the CHLO -> HLO -> LHLO -> Linalg -> loops -> LLVM pipeline, JIT-compiles
/// the result with mlir::ExecutionEngine and calls its `main` entry point on
/// two 6-element float buffers, printing six values of the returned buffer.
///
/// `argc`/`argv` are only forwarded to llvm::InitLLVM; no command-line
/// options are parsed.
///
/// Returns 0 on success and 1 as soon as any stage fails (reading the input,
/// host detection, parsing, lowering, JIT creation, symbol lookup, or a null
/// result buffer). A diagnostic is written to stderr on every failure path.
int main(int argc, char **argv) {
  llvm::InitLLVM y(argc, argv);
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  llvm::InitializeNativeTargetAsmParser();
  mlir::initializeLLVMPasses();

  // Register command line options and every pass we may reference.
  DebugCounter::registerCLOptions();
  mlir::registerAllPasses();
  mlir::mhlo::registerAllMhloPasses();
  mlir::lmhlo::registerAllLmhloPasses();
  mlir::disc_ral::registerAllDiscRalPasses();

  mlir::DialectRegistry registry;
  mlir::registerAllToLLVMIRTranslations(registry);
  mlir::registerAllDialects(registry);
  registry.insert<mlir::mhlo::MhloDialect>();
  registry.insert<mlir::chlo::HloClientDialect>();
  registry.insert<mlir::lmhlo::LmhloDialect>();
  registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
  registry.insert<mlir::disc_ral::RalDialect>();

  // Load the input. openInputFile returns null on failure, and handing a null
  // buffer to SourceMgr would crash later, so bail out early.
  std::string errorMessage;
  auto file =
      mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage);
  std::cout<<"errorMessage:" <<errorMessage <<std::endl;
  if (!file) {
    std::cerr << "failed to open input file: " << errorMessage << std::endl;
    return 1;
  }
  SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());

  // Build the LLVM-IR transformer used by the JIT. Both steps return
  // llvm::Expected, which must be checked before being dereferenced.
  SmallVector<const llvm::PassInfo *, 4> passes;
  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!tmBuilderOrError) {
    std::cerr << "failed to detect host target: "
              << llvm::toString(tmBuilderOrError.takeError()) << std::endl;
    return 1;
  }
  // `tmOrError` owns the TargetMachine and must stay alive for as long as
  // `transformer` may be invoked.
  auto tmOrError = tmBuilderOrError->createTargetMachine();
  if (!tmOrError) {
    std::cerr << "failed to create target machine: "
              << llvm::toString(tmOrError.takeError()) << std::endl;
    return 1;
  }
  auto transformer = mlir::makeLLVMPassesTransformer(
      passes, 0, /*targetMachine=*/tmOrError->get(), 0);

  MLIRContext context(registry);
  context.loadAllAvailableDialects();
  OwningModuleRef module(parseSourceFile(sourceMgr, &context));
  if (!module) {
    std::cerr << "failed to parse input module" << std::endl;
    return 1;
  }
  module->dump();

  // Equivalent command line:
  // RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
  // RUN: -hlo-legalize-to-lhlo -buffer-hoisting \
  // RUN: -buffer-deallocation -canonicalize -cse \
  // RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
  // RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \
  // RUN: -convert-std-to-llvm
  mlir::PassManager pm(&context, OpPassManager::Nesting::Implicit);
  applyPassManagerCLOptions(pm);
  pm.addPass(mlir::mhlo::createChloLegalizeToHloPass());
  pm.addPass(mlir::mhlo::createLegalizeToLhloPass());
  pm.addPass(mlir::createBufferHoistingPass());
  pm.addPass(mlir::createBufferDeallocationPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::lmhlo::createLegalizeLhloToLinalgPass());
  pm.addPass(mlir::lmhlo::createLhloFuseLinalgPass());
  pm.addPass(mlir::createConvertLinalgToLoopsPass());
  pm.addPass(mlir::createLowerAffinePass());
  pm.addPass(mlir::createLowerToCFGPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createLowerToLLVMPass());
  // A module that failed to lower must not be handed to the JIT.
  if (mlir::failed(pm.run(*module))) {
    std::cerr << "lowering pipeline failed" << std::endl;
    return 1;
  }
  module->dump();
  std::cout<<"DEBUG load module success"<<std::endl;

  llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default;

  // If a shared library implements custom mlir-runner library init and
  // destroy functions, we use them to register the library with the execution
  // engine. Otherwise we pass the library directly to the execution engine.
  // Use absolute library paths so that gdb can find the symbol table.
  const char *const sharedLibs[] = {
      "/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git"};
  SmallVector<SmallString<256>, 4> libPaths;
  for (const char *sharedLib : sharedLibs) {
    SmallString<256> absPath{StringRef(sharedLib)};
    cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
    libPaths.push_back(absPath);
  }

  // Libraries that we'll pass to the ExecutionEngine for loading.
  SmallVector<StringRef, 4> executionEngineLibs;
  using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
  using MlirRunnerDestroyFn = void (*)();
  llvm::StringMap<void *> exportSymbols;
  SmallVector<MlirRunnerDestroyFn> destroyFns;

  // Handle libraries that do support mlir-runner init/destroy callbacks.
  for (auto &libPath : libPaths) {
    auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
    void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
    void *destroySym = lib.getAddressOfSymbol("__mlir_runner_destroy");
    // Library does not support mlir runner, load it with ExecutionEngine.
    if (!initSym || !destroySym) {
      executionEngineLibs.push_back(libPath);
      continue;
    }
    reinterpret_cast<MlirRunnerInitFn>(initSym)(exportSymbols);
    destroyFns.push_back(reinterpret_cast<MlirRunnerDestroyFn>(destroySym));
  }

  // From here on every exit path must run the library destroy callbacks,
  // since their matching init callbacks have already been invoked.
  auto runDestroyFns = [&destroyFns] {
    llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
  };

  // Build a runtime symbol map from the exported symbols.
  auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
    auto symbolMap = llvm::orc::SymbolMap();
    for (auto &exportSymbol : exportSymbols)
      symbolMap[interner(exportSymbol.getKey())] =
          llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
    return symbolMap;
  };

  auto expectedEngine = mlir::ExecutionEngine::create(
      module.get(), /*llvmModuleBuilder=*/nullptr, transformer,
      jitCodeGenOptLevel, executionEngineLibs);
  if (!expectedEngine) {
    std::cerr << "failed to create execution engine: "
              << llvm::toString(expectedEngine.takeError()) << std::endl;
    runDestroyFns();
    return 1;
  }
  auto engine = std::move(*expectedEngine);
  engine->registerSymbols(runtimeSymbolMap);

  auto expectedFPtr = engine->lookup("main");
  if (!expectedFPtr) {
    // The error must be consumed: an unhandled llvm::Error aborts the process
    // in builds with assertions enabled.
    std::cerr << "failed to look up 'main': "
              << llvm::toString(expectedFPtr.takeError()) << std::endl;
    runDestroyFns();
    return 1;
  }

  // Inputs: both operands view the same six floats.
  float rawdata[6] = {0,1,2,3,4,5};
  int64_t dims = 1;
  utils::MemRefDescriptor<float,1> a{rawdata,rawdata,0,{6},{1}};
  utils::MemRefDescriptor<float,1> b{rawdata,rawdata,0,{6},{1}};

  // Result storage. Value-initialised so that a callee which does not fill it
  // in is detected below instead of causing a read through a garbage pointer.
  utils::MemRefDescriptor<float,1> result_memref{};
  struct memref_type{
    int64_t res_size = 6;
    utils::MemRefDescriptor<float,1> *memref;
  } result;
  result.memref = &result_memref;

  // Packed argument list: every slot holds a pointer to the storage of one
  // argument, and the final slot points at the result storage.
  // NOTE(review): the (size, descriptor pointer) pairing presumably matches
  // the signature of `main` in tests/test.mlir -- confirm against that file.
  void *a_ptr = &a;
  void *b_ptr = &b;
  void *packedArgs[] = {&dims, &a_ptr, &dims, &b_ptr, &result};

  void (*fptr)(void **) = *expectedFPtr;
  (*fptr)(packedArgs);

  if (!result.memref || !result.memref->allocated) {
    std::cerr << "JIT-compiled 'main' returned a null result buffer"
              << std::endl;
    runDestroyFns();
    return 1;
  }
  for (int i = 0; i < 6; ++i)
    std::cout<<"result: "<<result.memref->allocated[i]<<std::endl;

  // Run all dynamic library destroy callbacks to prepare for the shutdown.
  runDestroyFns();
  return 0;
}