add cpu runner in llvm

This commit is contained in:
colin.liang 2021-07-12 21:03:29 +08:00
parent 20ff8b4c93
commit 62e7b883c7
3 changed files with 295 additions and 6 deletions

8
BUILD
View File

@ -1373,12 +1373,20 @@ cc_binary(
":hlo", ":hlo",
":lhlo", ":lhlo",
":lhlo_gpu", ":lhlo_gpu",
"@llvm-project//llvm:AllTargetsAsmParsers",
"@llvm-project//llvm:AllTargetsCodeGens",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:ExecutionEngine",
"@llvm-project//llvm:Option",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:MlirJitRunner",
], ],
) )

39
tests/test.mlir Normal file
View File

@ -0,0 +1,39 @@
func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
func @main() {
%c0 = constant 0 : index
%c1 = constant 1 : index
// Initialize input.
%input = memref.alloc() : memref<2x3xf32>
%dim_x = memref.dim %input, %c0 : memref<2x3xf32>
%dim_y = memref.dim %input, %c1 : memref<2x3xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
%i_i64 = index_cast %i : index to i64
%i_f32 = sitofp %i_i64 : i64 to f32
memref.store %i_f32, %input[%i, %j] : memref<2x3xf32>
}
%unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
// CHECK: [0, 0, 0]
// CHECK: [1, 1, 1]
%in = memref.tensor_load %input : memref<2x3xf32>
%add = "mhlo.add"(%in, %in) {name = "add.3"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%output = memref.buffer_cast %add : memref<2x3xf32>
%unranked_output = memref.cast %output : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
// CHECK: [0, 0, 0]
// CHECK: [2, 2, 2]
return
}
// ./mlir-hlo-opt -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -canonicalize -cse -lhlo-legalize-to-linalg -convert-linalg-to-loops -lower-affine -convert-scf-to-std -convert-std-to-llvm ../../../tests/test.mlir > a.mlir
// /root/mlir-hlo/llvm-build/bin/mlir-cpu-runner --entry-point-result=void -shared-libs=/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git a.mlir > b.mlir
// /root/mlir-hlo/llvm-build/bin/FileCheck --input-file b.mlir ../../../tests/test.mlir

View File

@ -13,23 +13,94 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "iostream"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/DebugCounter.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/Timing.h"
#include "mlir/Support/ToolUtilities.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/DynamicLibrary.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/JitRunner.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir/InitAllDialects.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
// #include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
using namespace mlir;
using namespace llvm;
int main(int argc, char **argv) { int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();
mlir::initializeLLVMPasses();
// Register any command line options.
// registerAsmPrinterCLOptions();
// registerMLIRContextCLOptions();
// registerPassManagerCLOptions();
// registerDefaultTimingManagerCLOptions();
DebugCounter::registerCLOptions();
mlir::registerAllPasses(); mlir::registerAllPasses();
mlir::mhlo::registerAllMhloPasses(); mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses(); mlir::lmhlo::registerAllLmhloPasses();
mlir::disc_ral::registerAllDiscRalPasses(); mlir::disc_ral::registerAllDiscRalPasses();
mlir::DialectRegistry registry; mlir::DialectRegistry registry;
mlir::registerAllToLLVMIRTranslations(registry);
mlir::registerAllDialects(registry); mlir::registerAllDialects(registry);
registry.insert<mlir::mhlo::MhloDialect>(); registry.insert<mlir::mhlo::MhloDialect>();
registry.insert<mlir::chlo::HloClientDialect>(); registry.insert<mlir::chlo::HloClientDialect>();
@ -37,7 +108,178 @@ int main(int argc, char **argv) {
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>(); registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
registry.insert<mlir::disc_ral::RalDialect>(); registry.insert<mlir::disc_ral::RalDialect>();
return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
registry, // failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
/*preloadDialectsInContext=*/false)); // registry,
// /*preloadDialectsInContext=*/false));
// return 0;
std::string errorMessage;
// auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir", &errorMessage);
auto file = mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage);
std::cout<<"errorMessage:" <<errorMessage <<std::endl;
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
SmallVector<const llvm::PassInfo *, 4> passes;
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
auto tmOrError = tmBuilderOrError->createTargetMachine();
auto transformer = mlir::makeLLVMPassesTransformer(
passes, 0, /*targetMachine=*/tmOrError->get(), 0);
MLIRContext context(registry);
context.loadAllAvailableDialects();
OwningModuleRef module(parseSourceFile(sourceMgr, &context));
module->dump();
auto errorHandler = [&](const Twine &msg) {
// emitError(UnknownLoc::get(context)) << msg;
return failure();
};
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
// RUN: -hlo-legalize-to-lhlo -buffer-hoisting \
// RUN: -buffer-deallocation -canonicalize -cse \
// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
// RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \
// RUN: -convert-std-to-llvm
mlir::PassManager pm(&context, OpPassManager::Nesting::Implicit);
// pm.enableVerifier(verifyPasses);
applyPassManagerCLOptions(pm);
// pm.enableTiming(timing);
pm.addPass(mlir::mhlo::createChloLegalizeToHloPass());
pm.addPass(mlir::mhlo::createLegalizeToLhloPass());
pm.addPass(mlir::createBufferHoistingPass());
pm.addPass(mlir::createBufferDeallocationPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::lmhlo::createLegalizeLhloToLinalgPass());
pm.addPass(mlir::lmhlo::createLhloFuseLinalgPass());
pm.addPass(mlir::createConvertLinalgToLoopsPass());
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createLowerToCFGPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createLowerToLLVMPass());
pm.run(*module);
module->dump();
std::cout<<"DEBUG load module success"<<std::endl;
llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default;
// If shared library implements custom mlir-runner library init and destroy
// functions, we'll use them to register the library with the execution
// engine. Otherwise we'll pass library directly to the execution engine.
SmallVector<SmallString<256>, 4> libPaths;
// Use absolute library path so that gdb can find the symbol table.
std::list<std::string> sharedlib;
sharedlib.push_back("/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git");
transform(
sharedlib,
std::back_inserter(libPaths),
[](std::string libPath) {
SmallString<256> absPath(libPath.begin(), libPath.end());
cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
return absPath;
});
// Libraries that we'll pass to the ExecutionEngine for loading.
SmallVector<StringRef, 4> executionEngineLibs;
using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
using MlirRunnerDestroyFn = void (*)();
llvm::StringMap<void *> exportSymbols;
SmallVector<MlirRunnerDestroyFn> destroyFns;
// Handle libraries that do support mlir-runner init/destroy callbacks.
for (auto &libPath : libPaths) {
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
// Library does not support mlir runner, load it with ExecutionEngine.
if (!initSym || !destroySim) {
executionEngineLibs.push_back(libPath);
continue;
}
auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
initFn(exportSymbols);
auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
destroyFns.push_back(destroyFn);
}
// Build a runtime symbol map from the config and exported symbols.
auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
auto symbolMap = llvm::orc::SymbolMap();
for (auto &exportSymbol : exportSymbols)
symbolMap[interner(exportSymbol.getKey())] =
llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
return symbolMap;
};
auto expectedEngine = mlir::ExecutionEngine::create(
module.get(), nullptr, transformer, jitCodeGenOptLevel,
executionEngineLibs);
// if (!expectedEngine)
// return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
engine->registerSymbols(runtimeSymbolMap);
auto expectedFPtr = engine->lookup("main");
if (!expectedFPtr)
// std::cout<<"expectedFPtr "<<expectedFPtr.takeError()<<std::endl;
return 1;
// if (options.dumpObjectFile)
// engine->dumpToObjectFile("a.o");
int res;
struct {
void *data;
} data;
data.data = &res;
void (*fptr)(void **) = *expectedFPtr;
(*fptr)((void **)&data);
std::cout<<"result.data "<<res<<std::endl;
// Run all dynamic library destroy callbacks to prepare for the shutdown.
llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
return 0;
} }