[MLIR] compartmentalize build script (#369)

* compartmentalize build script, temporarily remove dependency of onnf_opt on helper.cpp

* fix test includes

* fix op directory include

* compiler -> op

* compiler test depends on boost system

* fix function name

* specify libcompiler dependencies

* let cmake take care of transitive dependencies

* remove unnecessary includes

* use ONNF_SRC_ROOT and ONNF_BIN_ROOT

* allow whole-archive linked libraries to be appended

* [MLIR] Support filecheck (#371)

* support lit+FileCheck

* add lit into build script

* format MLIR.cmake

* format cmake

* [MLIR] Remove input/output ops (#372)

* remove input/output ops

* get output tensor type from symbol table
Authored by Tian Jin on 2019-11-18 19:37:58 -05:00; committed by Tian Jin
parent dc36fd416b
commit d01ac7732f
15 changed files with 282 additions and 117 deletions

View File

@@ -42,3 +42,6 @@ cmake3 -DONNF_ENABLE_MODEL_TEST_CPP=ON \
 # Build and test:
 make -j "$(nproc)" install
 OMP_NUM_THREADS=20 OMP_THREAD_LIMIT=20 ctest3 -j "$(nproc)"
+
+# Run lit+FileCheck tests:
+make check-mlir-lit

View File

@@ -10,7 +10,13 @@ project(onnf)
 set(CMAKE_CXX_FLAGS_DEBUG "-g")
 set(CMAKE_CXX_FLAGS_RELEASE "-O2 -DNDEBUG")
 
-set(ONNF_ROOT "${CMAKE_CURRENT_SOURCE_DIR}")
+set(ONNF_SRC_ROOT "${CMAKE_CURRENT_SOURCE_DIR}")
+set(ONNF_BIN_ROOT "${CMAKE_CURRENT_BINARY_DIR}")
+set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
+set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
+#TODO(eventually enable the following)
+#set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
 
 add_subdirectory(third_party/onnx)
 add_subdirectory(third_party/benchmark)
@@ -36,7 +42,6 @@ if(Boost_FOUND)
 endif()
 
 include(MLIR.cmake)
 
 add_subdirectory(src/builder)
-add_subdirectory(src/compiler)
 add_subdirectory(src)
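With ONNF_SRC_ROOT and ONNF_BIN_ROOT defined at the top level, per-directory CMake files can refer to the source and build trees without relative paths like ../.. A minimal sketch of how a sub-project would consume the two variables, following the pattern used by the compiler library and onnf-opt later in this commit (the my_tool target and source file are hypothetical):

# Hypothetical tool under src/, mirroring the include setup used elsewhere in this commit.
add_executable(my_tool my_tool.cpp)
# Headers referenced as "src/compiler/..." resolve against the source root;
# TableGen-generated headers (*.hpp.inc, *.cpp.inc) land under the binary root.
target_include_directories(my_tool PRIVATE ${ONNF_SRC_ROOT})
target_include_directories(my_tool PRIVATE ${ONNF_BIN_ROOT})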

View File

@@ -1,14 +1,11 @@
-# Flags to link with LLVM/MLIR libraries
 # Path to LLVM source folder.
 if(DEFINED ENV{LLVM_SRC})
   set(LLVM_SRC $ENV{LLVM_SRC})
   if(EXISTS ${LLVM_SRC})
     message(STATUS "LLVM_SRC " ${LLVM_SRC})
   else()
-    message(
-        FATAL_ERROR "The path specified by LLVM_SRC does not exist: "
-        ${LLVM_SRC})
+    message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: "
+            ${LLVM_SRC})
   endif()
 else()
   message(FATAL_ERROR "env variable LLVM_SRC not set")
@@ -20,9 +17,8 @@ if(DEFINED ENV{LLVM_BUILD})
   if(EXISTS ${LLVM_BUILD})
     message(STATUS "LLVM_BUILD " ${LLVM_BUILD})
   else()
-    message(
-        FATAL_ERROR "The path specified by LLVM_BUILD does not exist: "
-        ${LLVM_BUILD})
+    message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: "
+            ${LLVM_BUILD})
   endif()
 else()
   message(FATAL_ERROR "env variable LLVM_BUILD not set")
@@ -36,9 +32,16 @@ set(LLVM_SRC_INCLUDE_PATH ${LLVM_SRC}/include)
 set(LLVM_BIN_INCLUDE_PATH ${LLVM_BUILD}/include)
 set(MLIR_SRC_INCLUDE_PATH ${LLVM_SRC}/projects/mlir/include)
 set(MLIR_BIN_INCLUDE_PATH ${LLVM_BUILD}/projects/mlir/include)
+set(MLIR_TOOLS_DIR ${LLVM_BUILD}/bin)
+set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/src/compiler/tool/onnf_opt)
+set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
+set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir)
 
-set(MLIR_INCLUDE_PATHS
-    ${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH})
+set(
+  MLIR_INCLUDE_PATHS
+  ${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH}
+)
 include_directories(${MLIR_INCLUDE_PATHS})
 
 find_library(MLIR_LIB_ANALYSIS
@@ -112,7 +115,6 @@ set(MLIRLIBS
     ${MLIR_LIB_OPT_MAIN}
     ${MLIR_LIB_SUPPORT}
     ${MLIR_LIB_TRANSFORM_UTILS}
-    ${MLIR_LIB_ANALYSIS}
     ${MLIR_LIB_IR}
     ${MLIR_LIB_PARSER}
@@ -123,32 +125,45 @@ set(MLIRLIBS
     ${MLIR_LIB_OPT_MAIN}
     ${MLIR_LIB_SUPPORT}
     ${MLIR_LIB_TRANSFORM_UTILS}
     ${LLVM_LIB_SUPPORT}
     Threads::Threads)
 
-function(whole_archive_link target)
+function(whole_archive_link target lib_dir)
+  get_property(link_flags TARGET ${target} PROPERTY LINK_FLAGS)
   if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
-    set(link_flags "-L${LLVM_BUILD}/lib ")
-    FOREACH(LIB ${ARGN})
-      string(CONCAT link_flags ${link_flags} "-Wl,-force_load ${LLVM_BUILD}/lib/lib${LIB}.a ")
-    ENDFOREACH(LIB)
+    set(link_flags "${link_flags} -L${lib_dir} ")
+    foreach(LIB ${ARGN})
+      string(CONCAT link_flags ${link_flags}
+             "-Wl,-force_load ${lib_dir}/lib${LIB}.a ")
+    endforeach(LIB)
   elseif(MSVC)
-    FOREACH(LIB ${ARGN})
+    foreach(LIB ${ARGN})
       string(CONCAT link_flags ${link_flags} "/WHOLEARCHIVE:${LIB} ")
-    ENDFOREACH(LIB)
+    endforeach(LIB)
   else()
-    set(link_flags "-L${LLVM_BUILD}/lib -Wl,--whole-archive,")
-    FOREACH(LIB ${ARGN})
+    set(link_flags "${link_flags} -L${lib_dir} -Wl,--whole-archive,")
+    foreach(LIB ${ARGN})
       string(CONCAT link_flags ${link_flags} "-l${LIB},")
-    ENDFOREACH(LIB)
+    endforeach(LIB)
     string(CONCAT link_flags ${link_flags} "--no-whole-archive")
   endif()
   set_target_properties(${target} PROPERTIES LINK_FLAGS ${link_flags})
 endfunction(whole_archive_link)
 
-# Set up TableGen environment.
-include(${LLVM_BUILD}/lib/cmake/llvm/TableGen.cmake)
+function(whole_archive_link_mlir target)
+  whole_archive_link(${target} ${LLVM_BUILD}/lib ${ARGN})
+endfunction(whole_archive_link_mlir)
+
+function(whole_archive_link_onnf target)
+  whole_archive_link(${target} ${CMAKE_BINARY_DIR}/lib ${ARGN})
+endfunction(whole_archive_link_onnf)
+
+set(LLVM_CMAKE_DIR
+    "${LLVM_BUILD}/lib/cmake/llvm"
+    CACHE PATH "Path to LLVM cmake modules")
+list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
+include(AddLLVM)
+include(TableGen)
 
 function(onnf_tablegen ofn)
   tablegen(MLIR
@@ -165,7 +180,5 @@ endfunction()
 # table gen utility itself can be detected and cause re-compilation of .td file.
 add_executable(mlir-tblgen IMPORTED)
 set_property(TARGET mlir-tblgen
-             PROPERTY IMPORTED_LOCATION
-             ${LLVM_BUILD}/bin/mlir-tblgen)
+             PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen)
 set(MLIR_TABLEGEN_EXE mlir-tblgen)
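whole_archive_link now takes the library directory as an argument and appends to any LINK_FLAGS already set on the target, so whole-archive groups from different library directories can be stacked on one executable. A short sketch of the intended call pattern (the whole_archive_link_mlir call mirrors the onnf-opt tool CMakeLists added later in this commit; force-loading a locally built library via whole_archive_link_onnf is an assumption about future use, not something this commit does yet):

add_executable(onnf-opt onnf_opt.cpp)
# Force-load selected MLIR static libraries from ${LLVM_BUILD}/lib so that
# statically registered passes and dialects are not dropped by the linker.
whole_archive_link_mlir(onnf-opt MLIRAffineOps MLIRLoopOps MLIRTransforms)
# Hypothetical: force-load a library built in this tree (${CMAKE_BINARY_DIR}/lib).
whole_archive_link_onnf(onnf-opt compiler)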

View File

@@ -18,6 +18,7 @@
 #include <regex>
 #include <string>
 #include <tuple>
+#include <map>
 
 #include "mlir/Analysis/Verifier.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
@@ -34,9 +35,10 @@
 #include "llvm/ADT/ScopedHashTable.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include "frontend_dialect_transformer.hpp"
+
 #include "src/compiler/dialect/onnx/onnx_ops.hpp"
-#include "frontend_dialect_transformer.hpp"
 
 namespace onnf {
 namespace {
@@ -147,7 +149,14 @@ class FrontendGenImpl {
     }
   }
 
-  void ImportInputTensor(onnx::ValueInfoProto& input) {
+  /*!
+   * Import an onnx input tensor type by determining and recording its type
+   * in a list of input tensor mlir types.
+   * @param input onnx input tensor ValueInfoProto.
+   * @param arg_types list of mlir types representing types of graph input.
+   */
+  void ImportInputTensorType(const onnx::ValueInfoProto& input,
+      llvm::SmallVector<mlir::Type, 4>& arg_types) {
     std::vector<int64_t> dims;
     auto shape_proto = input.type().tensor_type().shape();
     auto input_tensor_legalized_name = legalize_name(input.name());
@@ -165,20 +174,28 @@ class FrontendGenImpl {
         dims.push_back(-1);
       }
     }
-    if (!frontend_symbols_.ContainKey(input_tensor_legalized_name)) {
     mlir::Type elementType =
         TypeConvert(input.type().tensor_type().elem_type());
-    llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size());
-    auto dataType = mlir::RankedTensorType::get(llvmdimsAR, elementType);
-    mlir::OperationState result(
-        UnknownLoc(), "frontend.input " + input_tensor_legalized_name);
-    result.addTypes(dataType);
-    auto op = builder_.createOperation(result);
-    auto value = op->getResult(0);
-    frontend_symbols_.AddMapping(input_tensor_legalized_name, value);
-    } else {
-      // TODO Should not happen
-    }
+    llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
+    arg_types.emplace_back(
+        mlir::RankedTensorType::get(tensor_dims, elementType));
+  }
+
+  /*!
+   * Import a input tensor symbol by recording a new entry in frontend_symbols_
+   * recording the mapping between legalized onnx tensor name and mlir::Value*
+   * for further lookup in computation node importing.
+   * @param input onnx input tensor ValueInfoProto.
+   * @param symbol mlir input argument.
+   */
+  void ImportInputTensorSymbol(
+      const onnx::ValueInfoProto& input, mlir::Value* symbol) {
+    auto input_tensor_legalized_name = legalize_name(input.name());
+    assert(
+        !frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
+        "Found duplicate legalized input tensor names.");
+    frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
   }
 
   void ImportNode(onnx::NodeProto node) {
@@ -237,59 +254,80 @@ class FrontendGenImpl {
     // TODO more info from node: attributes
   }
 
-  void ImportOutputTensor(onnx::ValueInfoProto& output) {
+  /*!
+   * Import output tensor, by doing the following:
+   * - Add the type of this output tensor to a list of tensor
+   *   types representing return types of this graph function.
+   * - Add this output tensor to the list of mlir::Value*
+   *   to be returned by the function representing computation graph.
+   * @param output onnx output tensor ValueInfoProto.
+   * @param ret_types a vector of tensor types representing graph's
+   *   output tensor types.
+   * @param ret_vals a vector of mlir Value* representing graph's
+   *   output tensor.
+   */
+  void ImportOutputTensor(const onnx::ValueInfoProto& output,
+      llvm::SmallVectorImpl<mlir::Type>& ret_types,
+      llvm::SmallVectorImpl<mlir::Value*>& ret_vals) {
     auto output_tensor_legalized_name = legalize_name(output.name());
-    if (frontend_symbols_.ContainKey(output_tensor_legalized_name)) {
-      mlir::OperationState result(
-          UnknownLoc(), "frontend.output " + output_tensor_legalized_name);
-      mlir::Type elementType =
-          TypeConvert(output.type().tensor_type().elem_type());
-      result.addTypes(mlir::UnrankedTensorType::get(elementType));
-      result.addOperands(frontend_symbols_.GetTensorByOnnxName(
-          output_tensor_legalized_name));
-      builder_.createOperation(result);
-    } else {
-      // TODO: Why not in the symbol table? something is wrong
-      assert(false && "output name not found");
-    }
+    assert(
+        frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
+        "Output tensor not found");
+
+    auto tensor_val =
+        frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
+    ret_types.emplace_back(tensor_val->getType());
+    ret_vals.push_back(tensor_val);
   }
 
-  void ImportGraph(onnx::GraphProto graph) {
+  void ImportGraph(
+      const onnx::GraphProto& graph, const std::string& name = "main") {
     // create a function for the graph
     // TODO:
     //  * get name and type for the function.
     //  * maintain a list of the defined graph
-    llvm::SmallVector<mlir::Type, 4> ret_types;
     llvm::SmallVector<mlir::Type, 4> arg_types;
-    auto func_type = builder_.getFunctionType(arg_types, ret_types);
-    auto llvmfunction = mlir::FuncOp::create(
-        UnknownLoc(), graph.name(), func_type, /* attrs = */ {});
-    auto& entryBlock = *llvmfunction.addEntryBlock();
-    builder_.setInsertionPointToStart(&entryBlock);
-    module_.push_back(llvmfunction);
+
+    // Import the input tensor types.
+    for (const auto& input : graph.input()) {
+      ImportInputTensorType(input, arg_types);
+    }
 
     // TODO: import the initializer
-    //
+    auto func_type = builder_.getFunctionType(arg_types, {});
+    auto main_func =
+        mlir::FuncOp::create(UnknownLoc(), name, func_type, /* attrs = */ {});
+    auto& entryBlock = *main_func.addEntryBlock();
 
-    // import the input tensors
-    for (auto input : graph.input()) {
-      ImportInputTensor(input);
+    builder_.setInsertionPointToStart(&entryBlock);
+    module_.push_back(main_func);
+
+    for (auto it : llvm::zip(graph.input(), entryBlock.getArguments())) {
+      ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
     }
 
     // import nodes in the graph
     auto node = graph.node();
-    for (auto item : node) {
+    for (const auto& item : node) {
       ImportNode(item);
     }
 
-    // import the output tensors
-    for (auto output : graph.output()) {
-      ImportOutputTensor(output);
+    llvm::SmallVector<mlir::Type, 4> ret_types;
+    llvm::SmallVector<mlir::Value*, 4> ret_vals;
+    // Import the output tensors
+    for (const auto& output : graph.output()) {
+      ImportOutputTensor(output, ret_types, ret_vals);
     }
+
+    // Create a return operation to return all ONNX output tensors.
+    builder_.create<mlir::ReturnOp>(UnknownLoc(), ret_vals);
+    // Update main function signature to reflect types of newly imported
+    // output tensors.
+    func_type = builder_.getFunctionType(arg_types, ret_types);
+    main_func.setType(func_type);
   }
 };  // FrontendGenImpl class
 }  // namespace
 }  // namespace onnf

View File

@@ -14,12 +14,21 @@ add_library(
   pass/passes.hpp)
 
 # Include root src directory.
-target_include_directories(compiler PRIVATE ../..)
-target_include_directories(compiler PRIVATE ${CMAKE_SOURCE_DIR})
-target_include_directories(compiler PRIVATE ${CMAKE_BINARY_DIR})
+target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
+
+# Include third-party libraries.
+target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT}/third_party/isl/include)
+target_include_directories(compiler PRIVATE ${ONNF_BIN_ROOT}/third_party/isl/include)
+target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT}/third_party/Linq)
+target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT}/third_party/inja/src)
+target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT}/third_party/fmt/include)
+
+# Include tablegen generated header files.
+target_include_directories(compiler PRIVATE ${ONNF_BIN_ROOT})
 
 find_package(Boost 1.54.0
-             COMPONENTS graph
+             COMPONENTS
+             graph
              program_options
              log_setup
              log
@@ -27,20 +36,14 @@ find_package(Boost 1.54.0
             filesystem
             REQUIRED)
 
-# target_link_libraries(compiler isl inja ${Boost_LIBRARIES})
 target_link_libraries(compiler
                       ${Boost_LIBRARIES}
-                      ${MLIRLIBS} curses)
+                      ${CMAKE_THREAD_LIBS_INIT}
+                      ${CMAKE_DL_LIBS}
+                      ${MLIRLIBS}
+                      curses)
 
-add_executable(onnf-opt
-               tool/onnf_opt/onnf_opt.cpp)
-
-set(LIB_LIST MLIRAffineOps MLIRLoopOps MLIRTransformUtils MLIREDSC MLIRTransforms)
-whole_archive_link(onnf-opt ${LIB_LIST})
-target_link_libraries(onnf-opt ${Boost_LIBRARIES} ${MLIRLIBS} curses compiler)
-target_include_directories(onnf-opt PRIVATE ../..)
-target_include_directories(onnf-opt PRIVATE ${CMAKE_BINARY_DIR})
+add_subdirectory(tool)
 
 set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
 onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
@@ -64,4 +67,4 @@ onnf_tablegen(krnl.hpp.inc -gen-op-decls)
 onnf_tablegen(krnl.cpp.inc -gen-op-defs)
 add_public_tablegen_target(gen_krnl_ops)
 add_dependencies(compiler gen_krnl_ops)
 add_dependencies(onnf-opt gen_krnl_ops)

View File

@@ -21,29 +21,29 @@ class KrnlOpsDialect : public Dialect {
   KrnlOpsDialect(MLIRContext* context);
   static StringRef getDialectNamespace() { return "krnl"; }
 
-  /// Parse a type registered to this dialect. Overriding this method is
-  /// required for dialects that have custom types.
-  /// Technically this is only needed to be able to round-trip to textual IR.
-  mlir::Type parseType(
-      llvm::StringRef tyData, mlir::Location loc) const override {
-    MLIRContext* context = getContext();
-
-    if (tyData.consume_front("loop"))
-      return LoopType::get(context);
-    else
-      return (emitError(loc, "Unexpected type: " + tyData), Type());
-  }
-
-  /// Print a type registered to this dialect. Overriding this method is
-  /// only required for dialects that have custom types.
-  /// Technically this is only needed to be able to round-trip to textual IR.
-  void printType(mlir::Type type, llvm::raw_ostream& os) const override {
-    switch (type.getKind()) {
-      case KrnlTypes::Loop:
-        os << "loop";
-        return;
-    }
-  }
+  // /// Parse a type registered to this dialect. Overriding this method is
+  // /// required for dialects that have custom types.
+  // /// Technically this is only needed to be able to round-trip to textual IR.
+  // mlir::Type parseType(
+  //     llvm::StringRef tyData, mlir::Location loc) const override {
+  //   MLIRContext* context = getContext();
+  //
+  //   if (tyData.consume_front("loop"))
+  //     return LoopType::get(context);
+  //   else
+  //     return (emitError(loc, "Unexpected type: " + tyData), Type());
+  // }
+  //
+  // /// Print a type registered to this dialect. Overriding this method is
+  // /// only required for dialects that have custom types.
+  // /// Technically this is only needed to be able to round-trip to textual IR.
+  // void printType(mlir::Type type, llvm::raw_ostream& os) const override {
+  //   switch (type.getKind()) {
+  //     case KrnlTypes::Loop:
+  //       os << "loop";
+  //       return;
+  //   }
+  // }
 };
 
 #define GET_OP_CLASSES

View File

@@ -0,0 +1 @@
add_subdirectory(onnf_opt)

View File

@@ -0,0 +1,16 @@
add_executable(onnf-opt onnf_opt.cpp)
target_include_directories(onnf-opt PRIVATE ${ONNF_SRC_ROOT})
target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT})
set(LIB_LIST
MLIRStandardOps
MLIRAffineOps
MLIRLoopOps
MLIRTransformUtils
MLIREDSC
MLIRTransforms)
whole_archive_link_mlir(onnf-opt ${LIB_LIST})
# TODO: need to investigate how to whole-archive link compiler pass to onnf-opt.
target_link_libraries(onnf-opt compiler)

View File

@@ -19,7 +19,6 @@
 #include "src/compiler/dialect/krnl/krnl_ops.hpp"
 #include "src/compiler/dialect/onnx/onnx_ops.hpp"
-#include "src/compiler/helper.hpp"
 #include "src/compiler/pass/passes.hpp"
 
 using namespace onnf;
@@ -49,6 +48,7 @@ int main(int argc, char** argv) {
   llvm::InitLLVM y(argc, argv);
 
   mlir::registerDialect<mlir::ONNXOpsDialect>();
+  mlir::registerDialect<mlir::KrnlOpsDialect>();
 
   // Register any pass manager command line options.
   mlir::registerPassManagerCLOptions();
@@ -59,8 +59,10 @@ int main(int argc, char** argv) {
   // Set up the input file.
   std::string error_message;
   auto file = mlir::openInputFile(input_filename, &error_message);
+  assert(file);
 
   auto output = mlir::openOutputFile(output_filename, &error_message);
+  assert(output);
 
   return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
       split_input_file, verify_diagnostics, verify_passes));

View File

@@ -53,7 +53,7 @@ using namespace onnf;
 void LoadMLIR(string inputFilename, mlir::MLIRContext& context,
     mlir::OwningModuleRef& module) {
-  // Handle '.mlir' input to the DLC compiler.
+  // Handle '.mlir' input to the ONNF frontend.
   // The mlir format indicates that one or more of the supported
   // representations are used in the file.
   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =

test/CMakeLists.txt (new file, 4 lines)
View File

@@ -0,0 +1,4 @@
add_subdirectory(models)
add_subdirectory(nodes)
add_subdirectory(mlir)

test/mlir/CMakeLists.txt (new file, 21 lines)
View File

@@ -0,0 +1,21 @@
set(LLVM_LIT ${LLVM_SRC}/utils/lit/lit.py)
set(LLVM_DEFAULT_EXTERNAL_LIT ${LLVM_BUILD}/bin/llvm-lit)
configure_lit_site_cfg(${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
MAIN_CONFIG
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py)
set(ONNF_MLIR_TEST_DEPENDS onnf-opt)
add_lit_testsuite(check-mlir-lit
"Running the ONNF MLIR regression tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS
${ONNF_MLIR_TEST_DEPENDS})
set_target_properties(check-mlir-lit PROPERTIES FOLDER "Tests")
add_lit_testsuites(ONNF_MLIR
${CMAKE_CURRENT_SOURCE_DIR}
DEPENDS
${ONNF_MLIR_TEST_DEPS})
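Because lit discovers tests by the .mlir suffix configured through lit.site.cfg.py.in, adding a FileCheck test is just a matter of dropping a new .mlir file under test/mlir; no CMake change is needed. If the tests ever required another tool to be built first, it would be appended to the dependency list, roughly as sketched here (the onnf-translate target is hypothetical):

set(ONNF_MLIR_TEST_DEPENDS
    onnf-opt
    onnf-translate)  # hypothetical extra tool invoked by the tests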

test/mlir/lit.cfg.py (new file, 29 lines)
View File

@@ -0,0 +1,29 @@
import lit.formats
from lit.llvm import llvm_config
from lit.llvm.subst import ToolSubst
# name: The name of this test suite.
config.name = 'Open Neural Network Frontend'
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# test_source_root: The root path where tests are located.
config.test_source_root = config.onnf_mlir_test_src_dir
# test_exec_root: The root path where tests should be run.
config.test_exec_root = config.onnf_mlir_test_build_dir
llvm_config.use_default_substitutions()
# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
tool_dirs = [
config.onnf_mlir_tools_dir, config.mlir_tools_dir, config.llvm_tools_dir
]
tool_names = [
'onnf-opt', 'mlir-opt', 'mlir-translate'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@@ -0,0 +1,16 @@
import lit.llvm
config.llvm_tools_dir = "@MLIR_TOOLS_DIR@"
config.mlir_obj_root = "@MLIR_BUILD_DIR@"
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
config.suffixes = ['.mlir']
config.onnf_mlir_tools_dir = "@ONNF_TOOLS_DIR@"
config.onnf_mlir_test_src_dir = "@ONNF_LIT_TEST_SRC_DIR@"
config.onnf_mlir_test_build_dir = "@ONNF_LIT_TEST_BUILD_DIR@"
lit.llvm.initialize(lit_config, config)
# Let the main config do the real work.
lit_config.load_config(config, "@ONNF_LIT_TEST_SRC_DIR@/lit.cfg.py")

View File

@@ -0,0 +1,14 @@
// RUN: onnf-opt --canonicalize %s -split-input-file | FileCheck %s
//CHECK: module {
module {
func @test_sigmoid() {
%0 = "frontend.input t1"() : () -> tensor<10x10xf32>
%1 = "frontend.input t2"() : () -> tensor<10x10xf32>
%2 = "frontend.input t3"() : () -> tensor<10x10xf32>
// CHECK: %{{[0-9]+}} = "onnx.full_gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%3 = "onnx.matmul"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%4 = "onnx.add"(%3, %2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%5 = "frontend.output t4"(%4) : (tensor<10x10xf32>) -> tensor<10x10xf32>
}
}