add cpu runner in llvm
This commit is contained in:
parent
20ff8b4c93
commit
62e7b883c7
8
BUILD
8
BUILD
|
@ -1373,12 +1373,20 @@ cc_binary(
|
|||
":hlo",
|
||||
":lhlo",
|
||||
":lhlo_gpu",
|
||||
"@llvm-project//llvm:AllTargetsAsmParsers",
|
||||
"@llvm-project//llvm:AllTargetsCodeGens",
|
||||
"@llvm-project//llvm:Core",
|
||||
"@llvm-project//llvm:ExecutionEngine",
|
||||
"@llvm-project//llvm:Option",
|
||||
"@llvm-project//llvm:OrcJIT",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:Target",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:MlirOptLib",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:MlirJitRunner",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
|
||||
|
||||
func @main() {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
|
||||
// Initialize input.
|
||||
%input = memref.alloc() : memref<2x3xf32>
|
||||
%dim_x = memref.dim %input, %c0 : memref<2x3xf32>
|
||||
%dim_y = memref.dim %input, %c1 : memref<2x3xf32>
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
|
||||
%i_i64 = index_cast %i : index to i64
|
||||
%i_f32 = sitofp %i_i64 : i64 to f32
|
||||
memref.store %i_f32, %input[%i, %j] : memref<2x3xf32>
|
||||
}
|
||||
%unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32>
|
||||
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
|
||||
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
|
||||
// CHECK: [0, 0, 0]
|
||||
// CHECK: [1, 1, 1]
|
||||
|
||||
%in = memref.tensor_load %input : memref<2x3xf32>
|
||||
|
||||
%add = "mhlo.add"(%in, %in) {name = "add.3"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
|
||||
%output = memref.buffer_cast %add : memref<2x3xf32>
|
||||
%unranked_output = memref.cast %output : memref<2x3xf32> to memref<*xf32>
|
||||
|
||||
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
|
||||
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
|
||||
// CHECK: [0, 0, 0]
|
||||
// CHECK: [2, 2, 2]
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ./mlir-hlo-opt -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -canonicalize -cse -lhlo-legalize-to-linalg -convert-linalg-to-loops -lower-affine -convert-scf-to-std -convert-std-to-llvm ../../../tests/test.mlir > a.mlir
|
||||
// /root/mlir-hlo/llvm-build/bin/mlir-cpu-runner --entry-point-result=void -shared-libs=/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git a.mlir > b.mlir
|
||||
// /root/mlir-hlo/llvm-build/bin/FileCheck --input-file b.mlir ../../../tests/test.mlir
|
|
@ -13,23 +13,94 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "iostream"
|
||||
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/DebugCounter.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
#include "mlir/Support/Timing.h"
|
||||
#include "mlir/Support/ToolUtilities.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FileUtilities.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/StringSaver.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "llvm/Support/DynamicLibrary.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/ExecutionEngine/JitRunner.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/All.h"
|
||||
|
||||
|
||||
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
|
||||
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
|
||||
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
// #include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
|
||||
#include "llvm/ExecutionEngine/Orc/Mangling.h"
|
||||
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/LegacyPassNameParser.h"
|
||||
|
||||
|
||||
using namespace mlir;
|
||||
using namespace llvm;
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
llvm::InitializeNativeTargetAsmParser();
|
||||
mlir::initializeLLVMPasses();
|
||||
|
||||
// Register any command line options.
|
||||
// registerAsmPrinterCLOptions();
|
||||
// registerMLIRContextCLOptions();
|
||||
// registerPassManagerCLOptions();
|
||||
// registerDefaultTimingManagerCLOptions();
|
||||
DebugCounter::registerCLOptions();
|
||||
|
||||
|
||||
mlir::registerAllPasses();
|
||||
mlir::mhlo::registerAllMhloPasses();
|
||||
mlir::lmhlo::registerAllLmhloPasses();
|
||||
mlir::disc_ral::registerAllDiscRalPasses();
|
||||
|
||||
mlir::DialectRegistry registry;
|
||||
mlir::registerAllToLLVMIRTranslations(registry);
|
||||
mlir::registerAllDialects(registry);
|
||||
registry.insert<mlir::mhlo::MhloDialect>();
|
||||
registry.insert<mlir::chlo::HloClientDialect>();
|
||||
|
@ -37,7 +108,178 @@ int main(int argc, char **argv) {
|
|||
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
|
||||
registry.insert<mlir::disc_ral::RalDialect>();
|
||||
|
||||
return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
|
||||
registry,
|
||||
/*preloadDialectsInContext=*/false));
|
||||
|
||||
// failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
|
||||
// registry,
|
||||
// /*preloadDialectsInContext=*/false));
|
||||
// return 0;
|
||||
|
||||
|
||||
|
||||
std::string errorMessage;
|
||||
|
||||
// auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir", &errorMessage);
|
||||
auto file = mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage);
|
||||
std::cout<<"errorMessage:" <<errorMessage <<std::endl;
|
||||
|
||||
SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
|
||||
|
||||
|
||||
SmallVector<const llvm::PassInfo *, 4> passes;
|
||||
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
|
||||
auto tmOrError = tmBuilderOrError->createTargetMachine();
|
||||
|
||||
auto transformer = mlir::makeLLVMPassesTransformer(
|
||||
passes, 0, /*targetMachine=*/tmOrError->get(), 0);
|
||||
|
||||
|
||||
|
||||
MLIRContext context(registry);
|
||||
context.loadAllAvailableDialects();
|
||||
OwningModuleRef module(parseSourceFile(sourceMgr, &context));
|
||||
module->dump();
|
||||
|
||||
auto errorHandler = [&](const Twine &msg) {
|
||||
// emitError(UnknownLoc::get(context)) << msg;
|
||||
return failure();
|
||||
};
|
||||
|
||||
|
||||
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
|
||||
// RUN: -hlo-legalize-to-lhlo -buffer-hoisting \
|
||||
// RUN: -buffer-deallocation -canonicalize -cse \
|
||||
// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
|
||||
// RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \
|
||||
// RUN: -convert-std-to-llvm
|
||||
|
||||
mlir::PassManager pm(&context, OpPassManager::Nesting::Implicit);
|
||||
// pm.enableVerifier(verifyPasses);
|
||||
applyPassManagerCLOptions(pm);
|
||||
// pm.enableTiming(timing);
|
||||
|
||||
pm.addPass(mlir::mhlo::createChloLegalizeToHloPass());
|
||||
pm.addPass(mlir::mhlo::createLegalizeToLhloPass());
|
||||
|
||||
pm.addPass(mlir::createBufferHoistingPass());
|
||||
pm.addPass(mlir::createBufferDeallocationPass());
|
||||
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
|
||||
pm.addPass(mlir::lmhlo::createLegalizeLhloToLinalgPass());
|
||||
pm.addPass(mlir::lmhlo::createLhloFuseLinalgPass());
|
||||
pm.addPass(mlir::createConvertLinalgToLoopsPass());
|
||||
|
||||
pm.addPass(mlir::createLowerAffinePass());
|
||||
pm.addPass(mlir::createLowerToCFGPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
pm.addPass(mlir::createLowerToLLVMPass());
|
||||
|
||||
pm.run(*module);
|
||||
module->dump();
|
||||
|
||||
|
||||
|
||||
std::cout<<"DEBUG load module success"<<std::endl;
|
||||
|
||||
llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default;
|
||||
|
||||
// If shared library implements custom mlir-runner library init and destroy
|
||||
// functions, we'll use them to register the library with the execution
|
||||
// engine. Otherwise we'll pass library directly to the execution engine.
|
||||
SmallVector<SmallString<256>, 4> libPaths;
|
||||
|
||||
// Use absolute library path so that gdb can find the symbol table.
|
||||
|
||||
std::list<std::string> sharedlib;
|
||||
sharedlib.push_back("/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git");
|
||||
|
||||
transform(
|
||||
sharedlib,
|
||||
std::back_inserter(libPaths),
|
||||
[](std::string libPath) {
|
||||
SmallString<256> absPath(libPath.begin(), libPath.end());
|
||||
cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
|
||||
return absPath;
|
||||
});
|
||||
|
||||
// Libraries that we'll pass to the ExecutionEngine for loading.
|
||||
SmallVector<StringRef, 4> executionEngineLibs;
|
||||
|
||||
using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
|
||||
using MlirRunnerDestroyFn = void (*)();
|
||||
|
||||
llvm::StringMap<void *> exportSymbols;
|
||||
SmallVector<MlirRunnerDestroyFn> destroyFns;
|
||||
|
||||
|
||||
// Handle libraries that do support mlir-runner init/destroy callbacks.
|
||||
for (auto &libPath : libPaths) {
|
||||
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
|
||||
void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
|
||||
void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
|
||||
|
||||
// Library does not support mlir runner, load it with ExecutionEngine.
|
||||
if (!initSym || !destroySim) {
|
||||
executionEngineLibs.push_back(libPath);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
|
||||
initFn(exportSymbols);
|
||||
|
||||
auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
|
||||
destroyFns.push_back(destroyFn);
|
||||
}
|
||||
|
||||
// Build a runtime symbol map from the config and exported symbols.
|
||||
auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
|
||||
auto symbolMap = llvm::orc::SymbolMap();
|
||||
for (auto &exportSymbol : exportSymbols)
|
||||
symbolMap[interner(exportSymbol.getKey())] =
|
||||
llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
|
||||
return symbolMap;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
auto expectedEngine = mlir::ExecutionEngine::create(
|
||||
module.get(), nullptr, transformer, jitCodeGenOptLevel,
|
||||
executionEngineLibs);
|
||||
// if (!expectedEngine)
|
||||
// return expectedEngine.takeError();
|
||||
|
||||
|
||||
|
||||
auto engine = std::move(*expectedEngine);
|
||||
engine->registerSymbols(runtimeSymbolMap);
|
||||
|
||||
auto expectedFPtr = engine->lookup("main");
|
||||
if (!expectedFPtr)
|
||||
// std::cout<<"expectedFPtr "<<expectedFPtr.takeError()<<std::endl;
|
||||
return 1;
|
||||
|
||||
// if (options.dumpObjectFile)
|
||||
// engine->dumpToObjectFile("a.o");
|
||||
|
||||
int res;
|
||||
struct {
|
||||
void *data;
|
||||
} data;
|
||||
data.data = &res;
|
||||
|
||||
|
||||
void (*fptr)(void **) = *expectedFPtr;
|
||||
(*fptr)((void **)&data);
|
||||
|
||||
|
||||
std::cout<<"result.data "<<res<<std::endl;
|
||||
|
||||
// Run all dynamic library destroy callbacks to prepare for the shutdown.
|
||||
llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue