diff --git a/BUILD b/BUILD index a5c2329..adb73c7 100644 --- a/BUILD +++ b/BUILD @@ -1373,12 +1373,20 @@ cc_binary( ":hlo", ":lhlo", ":lhlo_gpu", + "@llvm-project//llvm:AllTargetsAsmParsers", + "@llvm-project//llvm:AllTargetsCodeGens", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:Option", + "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:MlirJitRunner", ], ) diff --git a/tests/test.mlir b/tests/test.mlir new file mode 100644 index 0000000..d950f1a --- /dev/null +++ b/tests/test.mlir @@ -0,0 +1,39 @@ +func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } + +func @main() { + %c0 = constant 0 : index + %c1 = constant 1 : index + + // Initialize input. + %input = memref.alloc() : memref<2x3xf32> + %dim_x = memref.dim %input, %c0 : memref<2x3xf32> + %dim_y = memref.dim %input, %c1 : memref<2x3xf32> + scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) { + %i_i64 = index_cast %i : index to i64 + %i_f32 = sitofp %i_i64 : i64 to f32 + memref.store %i_f32, %input[%i, %j] : memref<2x3xf32> + } + %unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32> + call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] + // CHECK: [0, 0, 0] + // CHECK: [1, 1, 1] + + %in = memref.tensor_load %input : memref<2x3xf32> + + %add = "mhlo.add"(%in, %in) {name = "add.3"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + + %output = memref.buffer_cast %add : memref<2x3xf32> + %unranked_output = memref.cast %output : memref<2x3xf32> to memref<*xf32> + + call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] + // CHECK: [0, 0, 0] + // CHECK: [2, 2, 2] + + return +} + +// ./mlir-hlo-opt -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -canonicalize -cse -lhlo-legalize-to-linalg -convert-linalg-to-loops -lower-affine -convert-scf-to-std -convert-std-to-llvm ../../../tests/test.mlir > a.mlir +// /root/mlir-hlo/llvm-build/bin/mlir-cpu-runner --entry-point-result=void -shared-libs=/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git a.mlir > b.mlir +// /root/mlir-hlo/llvm-build/bin/FileCheck --input-file b.mlir ../../../tests/test.mlir diff --git a/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tools/mlir-hlo-opt/mlir-hlo-opt.cpp index 9a9cee8..777b5db 100644 --- a/tools/mlir-hlo-opt/mlir-hlo-opt.cpp +++ b/tools/mlir-hlo-opt/mlir-hlo-opt.cpp @@ -13,23 +13,94 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "iostream" + +#include "mlir/Support/MlirOptMain.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/DebugCounter.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/Timing.h" +#include "mlir/Support/ToolUtilities.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileUtilities.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/Regex.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/StringSaver.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/DynamicLibrary.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/ExecutionEngine/JitRunner.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/Target/LLVMIR/Dialect/All.h" + + + +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" + + + #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" -#include "mlir/InitAllDialects.h" -#include "mlir/InitAllPasses.h" -#include "mlir/Support/MlirOptMain.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" + + +#include "mlir/ExecutionEngine/ExecutionEngine.h" +// #include "mlir/IR/BuiltinTypes.h" + + +#include "llvm/Support/TargetSelect.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/Mangling.h" + + +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassNameParser.h" + + +using namespace mlir; +using namespace llvm; int main(int argc, char **argv) { + + llvm::InitLLVM y(argc, argv); + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + mlir::initializeLLVMPasses(); + + // Register any command line options. + // registerAsmPrinterCLOptions(); + // registerMLIRContextCLOptions(); + // registerPassManagerCLOptions(); + // registerDefaultTimingManagerCLOptions(); + DebugCounter::registerCLOptions(); + + mlir::registerAllPasses(); mlir::mhlo::registerAllMhloPasses(); mlir::lmhlo::registerAllLmhloPasses(); mlir::disc_ral::registerAllDiscRalPasses(); mlir::DialectRegistry registry; + mlir::registerAllToLLVMIRTranslations(registry); mlir::registerAllDialects(registry); registry.insert(); registry.insert(); @@ -37,7 +108,178 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); - return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", - registry, - /*preloadDialectsInContext=*/false)); + + // failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", + // registry, + // /*preloadDialectsInContext=*/false)); + // return 0; + + + + std::string errorMessage; + + // auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir", &errorMessage); + auto file = mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage); + std::cout<<"errorMessage:" < passes; + auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); + auto tmOrError = tmBuilderOrError->createTargetMachine(); + + auto transformer = mlir::makeLLVMPassesTransformer( + passes, 0, /*targetMachine=*/tmOrError->get(), 0); + + + + MLIRContext context(registry); + context.loadAllAvailableDialects(); + OwningModuleRef module(parseSourceFile(sourceMgr, &context)); + module->dump(); + + auto errorHandler = [&](const Twine &msg) { + // emitError(UnknownLoc::get(context)) << msg; + return failure(); + }; + + +// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \ +// RUN: -hlo-legalize-to-lhlo -buffer-hoisting \ +// RUN: -buffer-deallocation -canonicalize -cse \ +// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \ +// RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \ +// RUN: -convert-std-to-llvm + + mlir::PassManager pm(&context, OpPassManager::Nesting::Implicit); + // pm.enableVerifier(verifyPasses); + applyPassManagerCLOptions(pm); + // pm.enableTiming(timing); + + pm.addPass(mlir::mhlo::createChloLegalizeToHloPass()); + pm.addPass(mlir::mhlo::createLegalizeToLhloPass()); + + pm.addPass(mlir::createBufferHoistingPass()); + pm.addPass(mlir::createBufferDeallocationPass()); + + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + + pm.addPass(mlir::lmhlo::createLegalizeLhloToLinalgPass()); + pm.addPass(mlir::lmhlo::createLhloFuseLinalgPass()); + pm.addPass(mlir::createConvertLinalgToLoopsPass()); + + pm.addPass(mlir::createLowerAffinePass()); + pm.addPass(mlir::createLowerToCFGPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createLowerToLLVMPass()); + + pm.run(*module); + module->dump(); + + + + std::cout<<"DEBUG load module success"<, 4> libPaths; + + // Use absolute library path so that gdb can find the symbol table. + + std::list sharedlib; + sharedlib.push_back("/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git"); + + transform( + sharedlib, + std::back_inserter(libPaths), + [](std::string libPath) { + SmallString<256> absPath(libPath.begin(), libPath.end()); + cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath))); + return absPath; + }); + + // Libraries that we'll pass to the ExecutionEngine for loading. + SmallVector executionEngineLibs; + + using MlirRunnerInitFn = void (*)(llvm::StringMap &); + using MlirRunnerDestroyFn = void (*)(); + + llvm::StringMap exportSymbols; + SmallVector destroyFns; + + + // Handle libraries that do support mlir-runner init/destroy callbacks. + for (auto &libPath : libPaths) { + auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str()); + void *initSym = lib.getAddressOfSymbol("__mlir_runner_init"); + void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy"); + + // Library does not support mlir runner, load it with ExecutionEngine. + if (!initSym || !destroySim) { + executionEngineLibs.push_back(libPath); + continue; + } + + auto initFn = reinterpret_cast(initSym); + initFn(exportSymbols); + + auto destroyFn = reinterpret_cast(destroySim); + destroyFns.push_back(destroyFn); + } + + // Build a runtime symbol map from the config and exported symbols. + auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) { + auto symbolMap = llvm::orc::SymbolMap(); + for (auto &exportSymbol : exportSymbols) + symbolMap[interner(exportSymbol.getKey())] = + llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue()); + return symbolMap; + }; + + + + + auto expectedEngine = mlir::ExecutionEngine::create( + module.get(), nullptr, transformer, jitCodeGenOptLevel, + executionEngineLibs); + // if (!expectedEngine) + // return expectedEngine.takeError(); + + + + auto engine = std::move(*expectedEngine); + engine->registerSymbols(runtimeSymbolMap); + + auto expectedFPtr = engine->lookup("main"); + if (!expectedFPtr) + // std::cout<<"expectedFPtr "<dumpToObjectFile("a.o"); + + int res; + struct { + void *data; + } data; + data.data = &res; + + + void (*fptr)(void **) = *expectedFPtr; + (*fptr)((void **)&data); + + + std::cout<<"result.data "<