From 652ce4b7d47c44a21d0f27623b8b96fc8ada32fc Mon Sep 17 00:00:00 2001 From: GHEORGHE-TEOD BERCEA Date: Tue, 26 Nov 2019 19:29:18 -0500 Subject: [PATCH] Add test for checking lowering of Add op to KRNL IR (#385) * Add test for checking lowering of Add op to KRNL IR. * Add test file. --- MLIR.cmake | 6 +++-- src/compiler/CMakeLists.txt | 19 +++++++++++++--- src/compiler/pass/lower_frontend_to_krnl.cpp | 3 +++ src/compiler/pass/passes.hpp | 2 +- src/compiler/pass/shape_inference_pass.cpp | 3 +++ src/compiler/tool/onnf_opt/CMakeLists.txt | 3 +++ test/mlir/onnx/onnx_lowering.mlir | 23 ++++++++++++++++++++ 7 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 test/mlir/onnx/onnx_lowering.mlir diff --git a/MLIR.cmake b/MLIR.cmake index a01cbb0..af9c75e 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -67,10 +67,10 @@ find_mlir_lib(MLIRStandardOps) find_mlir_lib(MLIRStandardToLLVM) find_mlir_lib(MLIRTargetLLVMIR) find_mlir_lib(MLIRTransforms) -find_mlir_lib(MLIRTransforms) find_mlir_lib(MLIRTransformUtils) find_mlir_lib(MLIRSupport) find_mlir_lib(MLIROptMain) +find_mlir_lib(MLIRVectorOps) find_mlir_lib(LLVMCore) find_mlir_lib(LLVMSupport) @@ -132,7 +132,9 @@ set(MLIRWholeArchiveLibs MLIRLLVMIR MLIRStandardOps MLIRStandardToLLVM - MLIRLoopToStandard) + MLIRTransforms + MLIRLoopToStandard + MLIRVectorOps) function(whole_archive_link target lib_dir) get_property(link_flags TARGET ${target} PROPERTY LINK_FLAGS) diff --git a/src/compiler/CMakeLists.txt b/src/compiler/CMakeLists.txt index 91dc4e0..480e2e0 100644 --- a/src/compiler/CMakeLists.txt +++ b/src/compiler/CMakeLists.txt @@ -45,8 +45,6 @@ target_link_libraries(compiler ${MLIRLibs} curses) -add_subdirectory(tool) - set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td) onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls) onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs) @@ -69,4 +67,19 @@ onnf_tablegen(krnl.hpp.inc -gen-op-decls) onnf_tablegen(krnl.cpp.inc -gen-op-defs) add_public_tablegen_target(gen_krnl_ops) add_dependencies(compiler gen_krnl_ops) -add_dependencies(onnf-opt gen_krnl_ops) + +add_library(onnf_shape_inference pass/shape_inference_pass.cpp) +target_include_directories(onnf_shape_inference + PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} + ${ONNF_SRC_ROOT}) +target_link_libraries(onnf_shape_inference ${MLIRLibs}) +add_dependencies(onnf_shape_inference gen_krnl_ops) + +add_library(onnf_lower_frontend pass/lower_frontend_to_krnl.cpp) +target_include_directories(onnf_lower_frontend + PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} + ${ONNF_SRC_ROOT}) +target_link_libraries(onnf_lower_frontend ${MLIRLibs}) +add_dependencies(onnf_lower_frontend gen_krnl_ops) + +add_subdirectory(tool) diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp index 78097c1..7401037 100644 --- a/src/compiler/pass/lower_frontend_to_krnl.cpp +++ b/src/compiler/pass/lower_frontend_to_krnl.cpp @@ -280,3 +280,6 @@ void FrontendToKrnlLoweringPass::runOnModule() { std::unique_ptr mlir::createLowerToKrnlPass() { return std::make_unique(); } + +static PassRegistration pass( + "lower-frontend", "Lower frontend ops to Krnl dialect."); diff --git a/src/compiler/pass/passes.hpp b/src/compiler/pass/passes.hpp index ee10dd0..a268021 100644 --- a/src/compiler/pass/passes.hpp +++ b/src/compiler/pass/passes.hpp @@ -17,7 +17,7 @@ class Pass; std::unique_ptr createShapeInferencePass(); -/// Add pass for lowering to Krnl IR. +/// Pass for lowering frontend dialects to Krnl IR dialect. std::unique_ptr createLowerToKrnlPass(); // TODO: Add pass for lowering to LLVM IR. diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp index 94e5389..fc3ed1f 100644 --- a/src/compiler/pass/shape_inference_pass.cpp +++ b/src/compiler/pass/shape_inference_pass.cpp @@ -105,3 +105,6 @@ class ShapeInferencePass : public mlir::FunctionPass { std::unique_ptr mlir::createShapeInferencePass() { return std::make_unique(); } + +static PassRegistration pass( + "shape-inference", "Shape inference for frontend dialects."); diff --git a/src/compiler/tool/onnf_opt/CMakeLists.txt b/src/compiler/tool/onnf_opt/CMakeLists.txt index 3ec11da..669e999 100644 --- a/src/compiler/tool/onnf_opt/CMakeLists.txt +++ b/src/compiler/tool/onnf_opt/CMakeLists.txt @@ -1,10 +1,13 @@ add_executable(onnf-opt onnf_opt.cpp) +add_dependencies(onnf-opt gen_krnl_ops) target_include_directories(onnf-opt PRIVATE ${ONNF_SRC_ROOT}) target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT}) target_link_libraries(onnf-opt compiler ${MLIRLibs}) whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs}) +whole_archive_link_onnf(onnf-opt onnf_lower_frontend) +whole_archive_link_onnf(onnf-opt onnf_shape_inference) # TODO: need to investigate how to whole-archive link compiler pass to onnf-opt. target_link_libraries(onnf-opt compiler) diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir new file mode 100644 index 0000000..9edbf58 --- /dev/null +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -0,0 +1,23 @@ +// RUN: onnf-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s + +module { + func @test_sigmoid(%a1 : tensor, %a2 : tensor) -> tensor<*xf32> { + %0 = "onnx.Add"(%a1, %a2) : (tensor, tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + } +} + +// CHECK: func @test_sigmoid([[ARG0:%.+]]: memref, [[ARG1:%.+]]: memref) -> memref { +// CHECK: [[DIM_0:%.+]] = dim [[ARG0]], 0 : memref +// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref +// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 +// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { +// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 +// CHECK: } : () -> (!krnl.loop, !krnl.loop) +// CHECK: [[DIM_2:%.+]] = dim [[ARG0]], 0 : memref +// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { +// CHECK: [[LOAD1:%.+]] = load [[ARG0]][%arg2, %arg3] : memref +// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref +// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32 +// CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref +// CHECK: return [[RES]] : memref