325 lines
9.9 KiB
C++
325 lines
9.9 KiB
C++
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "iostream"
|
|
|
|
#include "mlir/Support/MlirOptMain.h"
|
|
#include "mlir/IR/AsmState.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/IR/Location.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/Parser.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/DebugCounter.h"
|
|
#include "mlir/Support/FileUtilities.h"
|
|
#include "mlir/Support/Timing.h"
|
|
#include "mlir/Support/ToolUtilities.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/FileUtilities.h"
|
|
#include "llvm/Support/InitLLVM.h"
|
|
#include "llvm/Support/Regex.h"
|
|
#include "llvm/Support/SourceMgr.h"
|
|
#include "llvm/Support/StringSaver.h"
|
|
#include "llvm/Support/ToolOutputFile.h"
|
|
#include "llvm/Support/DynamicLibrary.h"
|
|
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/ExecutionEngine/JitRunner.h"
|
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
|
#include "mlir/Target/LLVMIR/Dialect/All.h"
|
|
|
|
|
|
|
|
#include "mlir/InitAllDialects.h"
|
|
#include "mlir/InitAllPasses.h"
|
|
|
|
|
|
|
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
|
#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
|
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
|
|
|
|
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
|
// #include "mlir/IR/BuiltinTypes.h"
|
|
|
|
|
|
#include "llvm/Support/TargetSelect.h"
|
|
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
|
|
#include "llvm/ExecutionEngine/Orc/Mangling.h"
|
|
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/LLVMContext.h"
|
|
#include "llvm/IR/LegacyPassNameParser.h"
|
|
|
|
|
|
using namespace mlir;
|
|
using namespace llvm;
|
|
|
|
namespace utils {

/// Host-side view of a ranked memref descriptor with element type `T` and
/// rank `N`, passed by address to the JIT-compiled function in `main`.
///
/// NOTE(review): the field order {allocated, aligned, offset, sizes[N],
/// strides[N]} is assumed to mirror MLIR's standard-to-LLVM memref descriptor
/// layout for the MLIR version in use. Do not reorder, retype, or add fields:
/// the JIT code reads and writes this memory directly.
template <class T, int N>
struct MemRefDescriptor {
  T* allocated;        // base pointer of the allocation (presumably the one to free)
  T* aligned;          // pointer that element accesses are based on
  int64_t offset;      // element offset applied to `aligned`
  int64_t sizes[N];    // extent of each of the N dimensions
  int64_t strides[N];  // per-dimension stride, in elements
};

}  // namespace utils
|
|
|
|
|
|
int main(int argc, char **argv) {
|
|
|
|
llvm::InitLLVM y(argc, argv);
|
|
llvm::InitializeNativeTarget();
|
|
llvm::InitializeNativeTargetAsmPrinter();
|
|
llvm::InitializeNativeTargetAsmParser();
|
|
mlir::initializeLLVMPasses();
|
|
|
|
// Register any command line options.
|
|
// registerAsmPrinterCLOptions();
|
|
// registerMLIRContextCLOptions();
|
|
// registerPassManagerCLOptions();
|
|
// registerDefaultTimingManagerCLOptions();
|
|
DebugCounter::registerCLOptions();
|
|
|
|
|
|
mlir::registerAllPasses();
|
|
mlir::mhlo::registerAllMhloPasses();
|
|
mlir::lmhlo::registerAllLmhloPasses();
|
|
mlir::disc_ral::registerAllDiscRalPasses();
|
|
|
|
mlir::DialectRegistry registry;
|
|
mlir::registerAllToLLVMIRTranslations(registry);
|
|
mlir::registerAllDialects(registry);
|
|
registry.insert<mlir::mhlo::MhloDialect>();
|
|
registry.insert<mlir::chlo::HloClientDialect>();
|
|
registry.insert<mlir::lmhlo::LmhloDialect>();
|
|
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
|
|
registry.insert<mlir::disc_ral::RalDialect>();
|
|
|
|
|
|
// failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
|
|
// registry,
|
|
// /*preloadDialectsInContext=*/false));
|
|
// return 0;
|
|
|
|
|
|
|
|
std::string errorMessage;
|
|
|
|
// auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir", &errorMessage);
|
|
auto file = mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage);
|
|
std::cout<<"errorMessage:" <<errorMessage <<std::endl;
|
|
|
|
SourceMgr sourceMgr;
|
|
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
|
|
|
|
|
|
SmallVector<const llvm::PassInfo *, 4> passes;
|
|
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
|
|
auto tmOrError = tmBuilderOrError->createTargetMachine();
|
|
|
|
auto transformer = mlir::makeLLVMPassesTransformer(
|
|
passes, 0, /*targetMachine=*/tmOrError->get(), 0);
|
|
|
|
|
|
|
|
MLIRContext context(registry);
|
|
context.loadAllAvailableDialects();
|
|
OwningModuleRef module(parseSourceFile(sourceMgr, &context));
|
|
module->dump();
|
|
|
|
auto errorHandler = [&](const Twine &msg) {
|
|
// emitError(UnknownLoc::get(context)) << msg;
|
|
return failure();
|
|
};
|
|
|
|
|
|
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
|
|
// RUN: -hlo-legalize-to-lhlo -buffer-hoisting \
|
|
// RUN: -buffer-deallocation -canonicalize -cse \
|
|
// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
|
|
// RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \
|
|
// RUN: -convert-std-to-llvm
|
|
|
|
mlir::PassManager pm(&context, OpPassManager::Nesting::Implicit);
|
|
// pm.enableVerifier(verifyPasses);
|
|
applyPassManagerCLOptions(pm);
|
|
// pm.enableTiming(timing);
|
|
|
|
pm.addPass(mlir::mhlo::createChloLegalizeToHloPass());
|
|
pm.addPass(mlir::mhlo::createLegalizeToLhloPass());
|
|
|
|
pm.addPass(mlir::createBufferHoistingPass());
|
|
pm.addPass(mlir::createBufferDeallocationPass());
|
|
|
|
pm.addPass(mlir::createCanonicalizerPass());
|
|
pm.addPass(mlir::createCSEPass());
|
|
|
|
pm.addPass(mlir::lmhlo::createLegalizeLhloToLinalgPass());
|
|
pm.addPass(mlir::lmhlo::createLhloFuseLinalgPass());
|
|
pm.addPass(mlir::createConvertLinalgToLoopsPass());
|
|
|
|
pm.addPass(mlir::createLowerAffinePass());
|
|
pm.addPass(mlir::createLowerToCFGPass());
|
|
pm.addPass(mlir::createCanonicalizerPass());
|
|
pm.addPass(mlir::createCSEPass());
|
|
pm.addPass(mlir::createLowerToLLVMPass());
|
|
|
|
pm.run(*module);
|
|
module->dump();
|
|
|
|
|
|
|
|
std::cout<<"DEBUG load module success"<<std::endl;
|
|
|
|
llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default;
|
|
|
|
// If shared library implements custom mlir-runner library init and destroy
|
|
// functions, we'll use them to register the library with the execution
|
|
// engine. Otherwise we'll pass library directly to the execution engine.
|
|
SmallVector<SmallString<256>, 4> libPaths;
|
|
|
|
// Use absolute library path so that gdb can find the symbol table.
|
|
|
|
std::list<std::string> sharedlib;
|
|
sharedlib.push_back("/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git");
|
|
|
|
transform(
|
|
sharedlib,
|
|
std::back_inserter(libPaths),
|
|
[](std::string libPath) {
|
|
SmallString<256> absPath(libPath.begin(), libPath.end());
|
|
cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
|
|
return absPath;
|
|
});
|
|
|
|
// Libraries that we'll pass to the ExecutionEngine for loading.
|
|
SmallVector<StringRef, 4> executionEngineLibs;
|
|
|
|
using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
|
|
using MlirRunnerDestroyFn = void (*)();
|
|
|
|
llvm::StringMap<void *> exportSymbols;
|
|
SmallVector<MlirRunnerDestroyFn> destroyFns;
|
|
|
|
|
|
// Handle libraries that do support mlir-runner init/destroy callbacks.
|
|
for (auto &libPath : libPaths) {
|
|
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
|
|
void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
|
|
void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
|
|
|
|
// Library does not support mlir runner, load it with ExecutionEngine.
|
|
if (!initSym || !destroySim) {
|
|
executionEngineLibs.push_back(libPath);
|
|
continue;
|
|
}
|
|
|
|
auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
|
|
initFn(exportSymbols);
|
|
|
|
auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
|
|
destroyFns.push_back(destroyFn);
|
|
}
|
|
|
|
// Build a runtime symbol map from the config and exported symbols.
|
|
auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
|
|
auto symbolMap = llvm::orc::SymbolMap();
|
|
for (auto &exportSymbol : exportSymbols)
|
|
symbolMap[interner(exportSymbol.getKey())] =
|
|
llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
|
|
return symbolMap;
|
|
};
|
|
|
|
|
|
|
|
|
|
auto expectedEngine = mlir::ExecutionEngine::create(
|
|
module.get(), nullptr, transformer, jitCodeGenOptLevel,
|
|
executionEngineLibs);
|
|
// if (!expectedEngine)
|
|
// return expectedEngine.takeError();
|
|
|
|
|
|
|
|
auto engine = std::move(*expectedEngine);
|
|
engine->registerSymbols(runtimeSymbolMap);
|
|
|
|
auto expectedFPtr = engine->lookup("main");
|
|
if (!expectedFPtr)
|
|
// std::cout<<"expectedFPtr "<<expectedFPtr.takeError()<<std::endl;
|
|
return 1;
|
|
|
|
// if (options.dumpObjectFile)
|
|
// engine->dumpToObjectFile("a.o");
|
|
|
|
|
|
float rawdata[6] = {0,1,2,3,4,5};
|
|
int64_t dims = 1;
|
|
utils::MemRefDescriptor<float,1> a{rawdata,rawdata,0,{6},{1}};
|
|
utils::MemRefDescriptor<float,1> b{rawdata,rawdata,0,{6},{1}};
|
|
utils::MemRefDescriptor<float,1> result_memref;
|
|
|
|
struct memref_type{
|
|
int64_t res_size = 6;
|
|
utils::MemRefDescriptor<float,1> *memref;
|
|
} result;
|
|
result.memref = &result_memref;
|
|
|
|
struct {
|
|
void *data1_size;
|
|
void *data1;
|
|
void *data2_size;
|
|
void *data2;
|
|
void *res;
|
|
} data;
|
|
|
|
data.data1_size = &dims;
|
|
void * a_ptr = &a;
|
|
data.data1 = &a_ptr;
|
|
data.data2_size = &dims;
|
|
void * b_ptr = &b;
|
|
data.data2 = &b_ptr;
|
|
void * result_ptr = &result;
|
|
data.res = &result;
|
|
|
|
void (*fptr)(void **) = *expectedFPtr;
|
|
(*fptr)((void **)&data);
|
|
|
|
std::cout<<"result: "<<result.memref->allocated[0]<<std::endl;
|
|
std::cout<<"result: "<<result.memref->allocated[1]<<std::endl;
|
|
std::cout<<"result: "<<result.memref->allocated[2]<<std::endl;
|
|
std::cout<<"result: "<<result.memref->allocated[3]<<std::endl;
|
|
std::cout<<"result: "<<result.memref->allocated[4]<<std::endl;
|
|
std::cout<<"result: "<<result.memref->allocated[5]<<std::endl;
|
|
|
|
// Run all dynamic library destroy callbacks to prepare for the shutdown.
|
|
llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
|
|
|
|
return 0;
|
|
}
|