Use TableGen (#347)

* use table gen

* fix name of the dialect

* add old compilation path

* add some doc

* fix bug, sgir importer imports every op twice

* knl.visit -> knl.iterate
This commit is contained in:
Tian Jin 2019-10-08 19:25:59 -04:00 committed by Doru Bercea
parent cc39a92802
commit 00aee2e0b6
10 changed files with 232 additions and 86 deletions

View File

@ -35,6 +35,9 @@ if(Boost_FOUND)
include_directories(${Boost_INCLUDE_DIRS}) include_directories(${Boost_INCLUDE_DIRS})
endif() endif()
include(MLIR.cmake)
add_subdirectory(src/builder) add_subdirectory(src/builder)
add_subdirectory(src/compiler)
add_subdirectory(src) add_subdirectory(src)

107
MLIR.cmake Normal file
View File

@ -0,0 +1,107 @@
# Flags to link with LLVM/MLIR libraries
if(DEFINED ENV{LLVM_PROJECT_ROOT})
set(LLVM_PROJECT_ROOT $ENV{LLVM_PROJECT_ROOT})
if(EXISTS ${LLVM_PROJECT_ROOT})
message(STATUS "LLVM_PROJECT_ROOT " ${LLVM_PROJECT_ROOT})
else()
message(
FATAL_ERROR "The path specified by LLVM_PROJECT_ROOT does not exist: "
${LLVM_PROJECT_ROOT})
endif()
else()
message(FATAL_ERROR "env variable LLVM_PROJECT_ROOT not set")
endif()
if(DEFINED ENV{LLVM_PROJECT_LIB})
set(LLVM_PROJECT_LIB $ENV{LLVM_PROJECT_LIB})
else()
set(LLVM_PROJECT_LIB $ENV{LLVM_PROJECT_ROOT}/build/lib)
endif()
if(EXISTS ${LLVM_PROJECT_LIB})
message(STATUS "LLVM_PROJECT_LIB " ${LLVM_PROJECT_LIB})
else()
message(FATAL_ERROR "The path specified by LLVM_PROJECT_LIB does not exist: "
${LLVM_PROJECT_LIB})
endif()
# include path
set(LLVM_SRC_INCLUDE_PATH ${LLVM_PROJECT_ROOT}/llvm/include)
set(LLVM_BIN_INCLUDE_PATH ${LLVM_PROJECT_ROOT}/build/include)
set(MLIR_SRC_INCLUDE_PATH ${LLVM_PROJECT_ROOT}/llvm/projects/mlir/include)
set(MLIR_BIN_INCLUDE_PATH ${LLVM_PROJECT_ROOT}/build/projects/mlir/include)
set(MLIR_INCLUDE_PATHS
${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH})
include_directories(${MLIR_INCLUDE_PATHS})
find_library(MLIRLIBANALYSIS
NAMES MLIRAnalysis
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIRLIBIR NAMES MLIRIR PATHS ${LLVM_PROJECT_LIB} NO_DEFAULT_PATH)
find_library(MLIRLIBPARSER
NAMES MLIRParser
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIRLIBTRANSFORMS
NAMES MLIRTransforms
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIRLIBVECTOROPS
NAMES MLIRVectorOps
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIRLIBSUPPORT
NAMES MLIRSupport
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIRLIBSTANDARDOPS
NAMES MLIRStandardOps
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(LLVMLIBSUPPORT
NAMES LLVMSupport
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
# libraries are set according to toy/Ch2
set(MLIRLIBS
${MLIRLIBANALYSIS}
${MLIRLIBIR}
${MLIRLIBPARSER}
${MLIRLIBTRANSFORMS}
${MLIRLIBANALYSIS}
${MLIRLIBVECTOROPS}
${MLIRLIBIR}
${MLIRLIBSUPPORT}
${MLIRLIBSTANDARDOPS}
${LLVMLIBSUPPORT})
# Set up TableGen environment.
include(${LLVM_PROJECT_ROOT}/build/lib/cmake/llvm/TableGen.cmake)
function(onnf_tablegen ofn)
tablegen(MLIR
${ARGV}
"-I${MLIR_SRC_INCLUDE_PATH}"
"-I${MLIR_BIN_INCLUDE_PATH}")
set(TABLEGEN_OUTPUT
${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
PARENT_SCOPE)
endfunction()
# Import the pre-built mlir TableGen as an imported exetuable. It is required by
# the LLVM TableGen command to have the TableGen target so that changes to the
# table gen utility itself can be detected and cause re-compilation of .td file.
add_executable(mlir-tblgen IMPORTED)
set_property(TARGET mlir-tblgen
PROPERTY IMPORTED_LOCATION
${LLVM_PROJECT_ROOT}/build/bin/mlir-tblgen)
set(MLIR_TABLEGEN_EXE mlir-tblgen)

View File

@ -1,84 +1,7 @@
add_definitions(-DBOOST_LOG_DYN_LINK)
add_library(builder add_library(builder
sgir.cpp sgir.cpp
) )
#Flags to link with LLVM/MLIR libraries
if(DEFINED ENV{LLVM_PROJECT_ROOT})
set(LLVM_PROJECT_ROOT $ENV{LLVM_PROJECT_ROOT})
if (EXISTS ${LLVM_PROJECT_ROOT})
message(STATUS "LLVM_PROJECT_ROOT " ${LLVM_PROJECT_ROOT})
else()
message(FATAL_ERROR "The path specified by LLVM_PROJECT_ROOT does not exist: " ${LLVM_PROJECT_ROOT})
endif()
else()
message(FATAL_ERROR "env variable LLVM_PROJECT_ROOT not set")
endif()
if(DEFINED ENV{LLVM_PROJECT_LIB})
set(LLVM_PROJECT_LIB $ENV{LLVM_PROJECT_LIB})
else()
set(LLVM_PROJECT_LIB $ENV{LLVM_PROJECT_ROOT}/build/lib)
endif()
if (EXISTS ${LLVM_PROJECT_LIB})
message(STATUS "LLVM_PROJECT_LIB " ${LLVM_PROJECT_LIB})
else()
message(FATAL_ERROR "The path specified by LLVM_PROJECT_LIB does not exist: " ${LLVM_PROJECT_LIB})
endif()
#include path
include_directories(${LLVM_PROJECT_ROOT}/llvm/projects/mlir/include)
include_directories(${LLVM_PROJECT_ROOT}/llvm/include)
include_directories(${LLVM_PROJECT_ROOT}/build/include)
include_directories(${LLVM_PROJECT_ROOT}/build/projects/mlir/include)
find_library(MLIRLIBANALYSIS
NAMES MLIRAnalysis
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIRLIBIR
NAMES MLIRIR
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIRLIBPARSER
NAMES MLIRParser
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIRLIBTRANSFORMS
NAMES MLIRTransforms
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIRLIBVECTOROPS
NAMES MLIRVectorOps
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIRLIBSUPPORT
NAMES MLIRSupport
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIRLIBSTANDARDOPS
NAMES MLIRStandardOps
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(LLVMLIBSUPPORT
NAMES LLVMSupport
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
#libraries are set according to toy/Ch2
set(MLIRLIBS ${MLIRLIBANALYSIS} ${MLIRLIBIR} ${MLIRLIBPARSER} ${MLIRLIBTRANSFORMS}
${MLIRLIBANALYSIS} ${MLIRLIBVECTOROPS} ${MLIRLIBIR} ${MLIRLIBSUPPORT} ${MLIRLIBSTANDARDOPS}
${LLVMLIBSUPPORT})
target_link_libraries(builder onnx ${MLIRLIBS} curses) target_link_libraries(builder onnx ${MLIRLIBS} curses)
target_include_directories(builder target_include_directories(builder
PRIVATE PRIVATE

View File

@ -157,7 +157,7 @@ class SGIRGenImpl {
result.addOperands(inputs); result.addOperands(inputs);
auto op = builder_.createOperation(result); auto op = builder_.createOperation(result);
for (int i = 0; i < node.output().size(); i++) { for (int i = 0; i < node.output().size(); i++) {
auto r = builder_.createOperation(result)->getResult(i); auto r = op->getResult(i);
sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r); sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r);
} }
@ -212,14 +212,10 @@ class SGIRGenImpl {
}; // SGIRGenImpl class }; // SGIRGenImpl class
} // namespace } // namespace
} // namespace dlc } // namespace onnf
namespace onnf { namespace onnf {
/*!
* Generate SGIR with MLIR for a onnx model
* @param model onnx model.
* @return module mlir module generated for the onnx model
*/
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model) { mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model) {
mlir::MLIRContext context; mlir::MLIRContext context;
SGIRGenImpl mySGIRGen(context); SGIRGenImpl mySGIRGen(context);
@ -229,4 +225,11 @@ mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model) {
return module; return module;
} }
mlir::OwningModuleRef SGIRImportModelFile(std::string model_fname) {
onnx::ModelProto model;
std::fstream input(model_fname, std::ios::in | std::ios::binary);
auto parse_success = model.ParseFromIstream(&input);
return SGIRImportModel(model);
}
} // namespace onnf } // namespace onnf

View File

@ -25,10 +25,16 @@ class OwningModuleRef;
namespace onnf { namespace onnf {
/*! /*!
* Import an ONNX Model into SGIR * Import an ONNX Model into SGIR.
* @param model onnx model. * @param model onnx model.
* @return MLIR::module generated for the ONNX model * @return MLIR::module generated for the ONNX model.
*/ */
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model); mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model);
/*!
* Import an ONNX Model file into SGIR.
* @param model_fname file name pointing to the onnx model protobuf.
* @return MLIR::module generated for the ONNX model.
*/
mlir::OwningModuleRef SGIRImportModelFile(std::string model_fname);
} // namespace onnf } // namespace onnf

View File

@ -0,0 +1,28 @@
add_library(
compiler
ir/knl/knl_ops.cpp
ir/knl/knl_ops.hpp)
# Include root src directory.
target_include_directories(compiler PRIVATE ../..)
target_include_directories(compiler PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
find_package(Boost 1.54.0
COMPONENTS graph
program_options
log_setup
log
system
filesystem
REQUIRED)
# target_link_libraries(compiler isl inja ${Boost_LIBRARIES})
target_link_libraries(compiler
${Boost_LIBRARIES}
)
set(LLVM_TARGET_DEFINITIONS ir/knl/knl.td)
onnf_tablegen(knl.hpp.inc -gen-op-decls)
onnf_tablegen(knl.cpp.inc -gen-op-defs)
add_public_tablegen_target(gen_kir)
add_dependencies(compiler gen_kir)

View File

@ -0,0 +1,27 @@
include "mlir/IR/OpBase.td"
def Knl_Dialect : Dialect {
let name = "knl";
let cppNamespace = "";
}
def KnlIterate : Op<Knl_Dialect, "iterate"> {
let summary = "iterate operation";
let description = [{
The "knl.iterate" operation is conceptually equivalent to a nested for loop
in that it represents ordered interation of integer coordinates within an
affine integer set.
}];
let arguments = (ins Variadic<AnyType>);
let regions = (region SizedRegion<1>:$region);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState &result, "
"IntegerSet set, ArrayRef<Value *> args">
];
}

View File

@ -0,0 +1,23 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "knl_ops.hpp"
namespace mlir {
KnlOpsDialect::KnlOpsDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
#include "knl.cpp.inc"
>();
}
} // namespace mlir
namespace onnf {}

View File

@ -0,0 +1,19 @@
#pragma once
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
class KnlOpsDialect : public Dialect {
public:
KnlOpsDialect(MLIRContext* context);
static StringRef getDialectNamespace() { return "knl"; }
};
#define GET_OP_CLASSES
#include "knl.hpp.inc"
} // namespace mlir
namespace onnf {}

View File

@ -20,6 +20,10 @@
#include <boost/program_options.hpp> #include <boost/program_options.hpp>
#include "src/builder/sgir.hpp"
#include "mlir/IR/Module.h"
using namespace std; using namespace std;
int main(int ac, char* av[]) { int main(int ac, char* av[]) {
@ -40,5 +44,8 @@ int main(int ac, char* av[]) {
return 0; return 0;
} }
string model_filename = vm["onnx-model"].as<string>();
auto module = SGIRImportModelFile(model_filename);
return 0; return 0;
} }