Add test for checking lowering of Add op to KRNL IR (#385)

* Add test for checking lowering of Add op to KRNL IR.

* Add test file.
This commit is contained in:
GHEORGHE-TEOD BERCEA 2019-11-26 19:29:18 -05:00 committed by Tian Jin
parent b46f965715
commit 652ce4b7d4
7 changed files with 53 additions and 6 deletions

View File

@ -67,10 +67,10 @@ find_mlir_lib(MLIRStandardOps)
find_mlir_lib(MLIRStandardToLLVM) find_mlir_lib(MLIRStandardToLLVM)
find_mlir_lib(MLIRTargetLLVMIR) find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransforms) find_mlir_lib(MLIRTransforms)
find_mlir_lib(MLIRTransforms)
find_mlir_lib(MLIRTransformUtils) find_mlir_lib(MLIRTransformUtils)
find_mlir_lib(MLIRSupport) find_mlir_lib(MLIRSupport)
find_mlir_lib(MLIROptMain) find_mlir_lib(MLIROptMain)
find_mlir_lib(MLIRVectorOps)
find_mlir_lib(LLVMCore) find_mlir_lib(LLVMCore)
find_mlir_lib(LLVMSupport) find_mlir_lib(LLVMSupport)
@ -132,7 +132,9 @@ set(MLIRWholeArchiveLibs
MLIRLLVMIR MLIRLLVMIR
MLIRStandardOps MLIRStandardOps
MLIRStandardToLLVM MLIRStandardToLLVM
MLIRLoopToStandard) MLIRTransforms
MLIRLoopToStandard
MLIRVectorOps)
function(whole_archive_link target lib_dir) function(whole_archive_link target lib_dir)
get_property(link_flags TARGET ${target} PROPERTY LINK_FLAGS) get_property(link_flags TARGET ${target} PROPERTY LINK_FLAGS)

View File

@ -45,8 +45,6 @@ target_link_libraries(compiler
${MLIRLibs} ${MLIRLibs}
curses) curses)
add_subdirectory(tool)
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td) set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls) onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs) onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
@ -69,4 +67,19 @@ onnf_tablegen(krnl.hpp.inc -gen-op-decls)
onnf_tablegen(krnl.cpp.inc -gen-op-defs) onnf_tablegen(krnl.cpp.inc -gen-op-defs)
add_public_tablegen_target(gen_krnl_ops) add_public_tablegen_target(gen_krnl_ops)
add_dependencies(compiler gen_krnl_ops) add_dependencies(compiler gen_krnl_ops)
add_dependencies(onnf-opt gen_krnl_ops)
add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
target_include_directories(onnf_shape_inference
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT})
target_link_libraries(onnf_shape_inference ${MLIRLibs})
add_dependencies(onnf_shape_inference gen_krnl_ops)
add_library(onnf_lower_frontend pass/lower_frontend_to_krnl.cpp)
target_include_directories(onnf_lower_frontend
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT})
target_link_libraries(onnf_lower_frontend ${MLIRLibs})
add_dependencies(onnf_lower_frontend gen_krnl_ops)
add_subdirectory(tool)

View File

@ -280,3 +280,6 @@ void FrontendToKrnlLoweringPass::runOnModule() {
std::unique_ptr<Pass> mlir::createLowerToKrnlPass() { std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
return std::make_unique<FrontendToKrnlLoweringPass>(); return std::make_unique<FrontendToKrnlLoweringPass>();
} }
static PassRegistration<FrontendToKrnlLoweringPass> pass(
"lower-frontend", "Lower frontend ops to Krnl dialect.");

View File

@ -17,7 +17,7 @@ class Pass;
std::unique_ptr<Pass> createShapeInferencePass(); std::unique_ptr<Pass> createShapeInferencePass();
/// Add pass for lowering to Krnl IR. /// Pass for lowering frontend dialects to Krnl IR dialect.
std::unique_ptr<mlir::Pass> createLowerToKrnlPass(); std::unique_ptr<mlir::Pass> createLowerToKrnlPass();
// TODO: Add pass for lowering to LLVM IR. // TODO: Add pass for lowering to LLVM IR.

View File

@ -105,3 +105,6 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() { std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
return std::make_unique<ShapeInferencePass>(); return std::make_unique<ShapeInferencePass>();
} }
static PassRegistration<ShapeInferencePass> pass(
"shape-inference", "Shape inference for frontend dialects.");

View File

@ -1,10 +1,13 @@
add_executable(onnf-opt onnf_opt.cpp) add_executable(onnf-opt onnf_opt.cpp)
add_dependencies(onnf-opt gen_krnl_ops)
target_include_directories(onnf-opt PRIVATE ${ONNF_SRC_ROOT}) target_include_directories(onnf-opt PRIVATE ${ONNF_SRC_ROOT})
target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT}) target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT})
target_link_libraries(onnf-opt compiler ${MLIRLibs}) target_link_libraries(onnf-opt compiler ${MLIRLibs})
whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs}) whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs})
whole_archive_link_onnf(onnf-opt onnf_lower_frontend)
whole_archive_link_onnf(onnf-opt onnf_shape_inference)
# TODO: need to investigate how to whole-archive link compiler pass to onnf-opt. # TODO: need to investigate how to whole-archive link compiler pass to onnf-opt.
target_link_libraries(onnf-opt compiler) target_link_libraries(onnf-opt compiler)

View File

@ -0,0 +1,23 @@
// RUN: onnf-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s
module {
func @test_sigmoid(%a1 : tensor<?x10xf32>, %a2 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Add"(%a1, %a2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
}
}
// CHECK: func @test_sigmoid([[ARG0:%.+]]: memref<?x10xf32>, [[ARG1:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
// CHECK: [[DIM_0:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[ARG0]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>