Add test for checking lowering of Add op to KRNL IR (#385)
* Add test for checking lowering of Add op to KRNL IR.
* Add test file.
parent b46f965715
commit 652ce4b7d4
@@ -67,10 +67,10 @@ find_mlir_lib(MLIRStandardOps)
 find_mlir_lib(MLIRStandardToLLVM)
 find_mlir_lib(MLIRTargetLLVMIR)
 find_mlir_lib(MLIRTransforms)
-find_mlir_lib(MLIRTransforms)
 find_mlir_lib(MLIRTransformUtils)
 find_mlir_lib(MLIRSupport)
 find_mlir_lib(MLIROptMain)
+find_mlir_lib(MLIRVectorOps)

 find_mlir_lib(LLVMCore)
 find_mlir_lib(LLVMSupport)
@@ -132,7 +132,9 @@ set(MLIRWholeArchiveLibs
         MLIRLLVMIR
         MLIRStandardOps
         MLIRStandardToLLVM
-        MLIRLoopToStandard)
+        MLIRTransforms
+        MLIRLoopToStandard
+        MLIRVectorOps)

 function(whole_archive_link target lib_dir)
   get_property(link_flags TARGET ${target} PROPERTY LINK_FLAGS)
@@ -45,8 +45,6 @@ target_link_libraries(compiler
         ${MLIRLibs}
         curses)

-add_subdirectory(tool)
-
 set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
 onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
 onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
@@ -69,4 +67,19 @@ onnf_tablegen(krnl.hpp.inc -gen-op-decls)
 onnf_tablegen(krnl.cpp.inc -gen-op-defs)
 add_public_tablegen_target(gen_krnl_ops)
 add_dependencies(compiler gen_krnl_ops)
-add_dependencies(onnf-opt gen_krnl_ops)
+
+add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
+target_include_directories(onnf_shape_inference
+        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
+        ${ONNF_SRC_ROOT})
+target_link_libraries(onnf_shape_inference ${MLIRLibs})
+add_dependencies(onnf_shape_inference gen_krnl_ops)
+
+add_library(onnf_lower_frontend pass/lower_frontend_to_krnl.cpp)
+target_include_directories(onnf_lower_frontend
+        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
+        ${ONNF_SRC_ROOT})
+target_link_libraries(onnf_lower_frontend ${MLIRLibs})
+add_dependencies(onnf_lower_frontend gen_krnl_ops)
+
+add_subdirectory(tool)
@@ -280,3 +280,6 @@ void FrontendToKrnlLoweringPass::runOnModule() {
 std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
   return std::make_unique<FrontendToKrnlLoweringPass>();
 }
+
+static PassRegistration<FrontendToKrnlLoweringPass> pass(
+    "lower-frontend", "Lower frontend ops to Krnl dialect.");
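Note: this registration (together with the matching one in the shape-inference pass below) is what makes the passes available as the --lower-frontend and --shape-inference flags that the new test's RUN line invokes. A minimal sketch of the MLIR registration pattern being used here, with ExamplePass and the example-pass flag as purely hypothetical names:

    #include "mlir/Pass/Pass.h"

    namespace {
    // Hypothetical function pass, used only to illustrate the pattern added
    // in this commit; it performs no rewriting.
    class ExamplePass : public mlir::FunctionPass<ExamplePass> {
      void runOnFunction() override {
        // A real pass would inspect and transform the current function here.
      }
    };
    } // end anonymous namespace

    // The static registration object is constructed at program start-up and
    // makes the pass selectable by flag name in opt-style tools.
    static mlir::PassRegistration<ExamplePass> registration(
        "example-pass", "Illustrative no-op pass registration.");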
@@ -17,7 +17,7 @@ class Pass;

 std::unique_ptr<Pass> createShapeInferencePass();

-/// Add pass for lowering to Krnl IR.
+/// Pass for lowering frontend dialects to Krnl IR dialect.
 std::unique_ptr<mlir::Pass> createLowerToKrnlPass();

 // TODO: Add pass for lowering to LLVM IR.
@@ -105,3 +105,6 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
 std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
   return std::make_unique<ShapeInferencePass>();
 }
+
+static PassRegistration<ShapeInferencePass> pass(
+    "shape-inference", "Shape inference for frontend dialects.");
@@ -1,10 +1,13 @@
 add_executable(onnf-opt onnf_opt.cpp)
+add_dependencies(onnf-opt gen_krnl_ops)

 target_include_directories(onnf-opt PRIVATE ${ONNF_SRC_ROOT})
 target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT})

 target_link_libraries(onnf-opt compiler ${MLIRLibs})
 whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs})
+whole_archive_link_onnf(onnf-opt onnf_lower_frontend)
+whole_archive_link_onnf(onnf-opt onnf_shape_inference)

 # TODO: need to investigate how to whole-archive link compiler pass to onnf-opt.
 target_link_libraries(onnf-opt compiler)
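The whole_archive_link_onnf calls matter because the two new libraries contribute only the static PassRegistration objects above; onnf-opt never references their symbols directly, so a plain static-library link could drop them and the --shape-inference / --lower-frontend flags would silently disappear. A small sketch of that failure mode, with hypothetical names (assuming the usual static-initializer-in-a-static-library linker behavior):

    // Placed in its own .cpp file inside a static library.
    #include <cstdio>

    struct RegisterOnLoad {
      // Start-up side effect, analogous to a static PassRegistration.
      RegisterOnLoad() { std::puts("pass registered"); }
    };
    // If no symbol from this translation unit is referenced by the final
    // binary, a normal archive link may omit the object file entirely and
    // this constructor never runs; whole-archive linking forces it in.
    static RegisterOnLoad registerOnLoad;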
@@ -0,0 +1,23 @@
+// RUN: onnf-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s
+
+module {
+  func @test_sigmoid(%a1 : tensor<?x10xf32>, %a2 : tensor<?x10xf32>) -> tensor<*xf32> {
+    %0 = "onnx.Add"(%a1, %a2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+    "std.return"(%0) : (tensor<*xf32>) -> ()
+  }
+}
+
+// CHECK: func @test_sigmoid([[ARG0:%.+]]: memref<?x10xf32>, [[ARG1:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
+// CHECK: [[DIM_0:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
+// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+// CHECK: } : () -> (!krnl.loop, !krnl.loop)
+// CHECK: [[DIM_2:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
+// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+// CHECK: [[LOAD1:%.+]] = load [[ARG0]][%arg2, %arg3] : memref<?x10xf32>
+// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
+// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+// CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+// CHECK: return [[RES]] : memref<?x10xf32>
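For readers new to the Krnl dialect, the CHECK lines pin down an allocation sized by the runtime first dimension followed by an element-wise load/addf/store loop nest. A rough C++ analogue of the computation being checked (illustrative only; add_2d and the flat row-major layout are assumptions, not part of the generated code):

    #include <cstddef>
    #include <vector>

    // Element-wise add over two row-major ?x10 buffers, mirroring the
    // lowered krnl.iterate nest: outer bound is the runtime dim, inner is 10.
    std::vector<float> add_2d(const std::vector<float> &a,
                              const std::vector<float> &b, std::size_t dim0) {
      constexpr std::size_t dim1 = 10;       // static second dimension
      std::vector<float> res(dim0 * dim1);   // mirrors alloc([[DIM_0]])
      for (std::size_t i = 0; i < dim0; ++i)     // DEF_LOOPS#0
        for (std::size_t j = 0; j < dim1; ++j)   // DEF_LOOPS#1
          res[i * dim1 + j] = a[i * dim1 + j] + b[i * dim1 + j]; // load, addf, store
      return res;
    }

The RUN line at the top of the file drives the --shape-inference then --lower-frontend pipeline through onnf-opt and pipes the output to FileCheck, so the file is intended to run under an MLIR-style lit harness.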