[MLIR] Refactor Krnl Dialect and Krnl Dialect Lowering (#375)
* Store bounds as affine map attributes & check in test cases with generic printer * Upgrading MLIR MLIR is outdated on Buildbot, rebuilding a newer version. * work with new version of mlir * check-in parser tests * custom printer * nit * bug fix * enable custom asm printer test * enable custom asm printer test * more consistent variable naming * test max/min * variable naming scheme change to MLIR style * can lower krnl to llvm * kernel -> llvm * comments * bug fix * try fixing ci * fix ci * deactivate model test * fix lit test * nit * fix z buildbot
This commit is contained in:
parent
652ce4b7d4
commit
b2a1103915
|
@ -18,6 +18,8 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
|
||||||
#TODO(eventually enable the following)
|
#TODO(eventually enable the following)
|
||||||
#set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
#set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||||
|
|
||||||
|
include(MLIR.cmake)
|
||||||
|
|
||||||
add_subdirectory(third_party/onnx)
|
add_subdirectory(third_party/onnx)
|
||||||
add_subdirectory(third_party/benchmark)
|
add_subdirectory(third_party/benchmark)
|
||||||
add_subdirectory(third_party/pybind11)
|
add_subdirectory(third_party/pybind11)
|
||||||
|
@ -41,7 +43,6 @@ if(Boost_FOUND)
|
||||||
include_directories(${Boost_INCLUDE_DIRS})
|
include_directories(${Boost_INCLUDE_DIRS})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
include(MLIR.cmake)
|
|
||||||
add_subdirectory(src/builder)
|
add_subdirectory(src/builder)
|
||||||
add_subdirectory(src/compiler)
|
add_subdirectory(src/compiler)
|
||||||
add_subdirectory(src)
|
add_subdirectory(src)
|
||||||
|
|
92
MLIR.cmake
92
MLIR.cmake
|
@ -5,7 +5,7 @@ if(DEFINED ENV{LLVM_SRC})
|
||||||
message(STATUS "LLVM_SRC " ${LLVM_SRC})
|
message(STATUS "LLVM_SRC " ${LLVM_SRC})
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: "
|
message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: "
|
||||||
${LLVM_SRC})
|
${LLVM_SRC})
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "env variable LLVM_SRC not set")
|
message(FATAL_ERROR "env variable LLVM_SRC not set")
|
||||||
|
@ -18,7 +18,7 @@ if(DEFINED ENV{LLVM_BUILD})
|
||||||
message(STATUS "LLVM_BUILD " ${LLVM_BUILD})
|
message(STATUS "LLVM_BUILD " ${LLVM_BUILD})
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: "
|
message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: "
|
||||||
${LLVM_BUILD})
|
${LLVM_BUILD})
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "env variable LLVM_BUILD not set")
|
message(FATAL_ERROR "env variable LLVM_BUILD not set")
|
||||||
|
@ -39,9 +39,9 @@ set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
|
||||||
set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir)
|
set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir)
|
||||||
|
|
||||||
set(
|
set(
|
||||||
MLIR_INCLUDE_PATHS
|
MLIR_INCLUDE_PATHS
|
||||||
${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH}
|
${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH}
|
||||||
)
|
)
|
||||||
include_directories(${MLIR_INCLUDE_PATHS})
|
include_directories(${MLIR_INCLUDE_PATHS})
|
||||||
|
|
||||||
# Threading libraries required due to parallel pass execution.
|
# Threading libraries required due to parallel pass execution.
|
||||||
|
@ -49,9 +49,9 @@ find_package(Threads REQUIRED)
|
||||||
|
|
||||||
function(find_mlir_lib lib)
|
function(find_mlir_lib lib)
|
||||||
find_library(${lib}
|
find_library(${lib}
|
||||||
NAMES ${lib}
|
NAMES ${lib}
|
||||||
PATHS ${LLVM_PROJECT_LIB}
|
PATHS ${LLVM_PROJECT_LIB}
|
||||||
NO_DEFAULT_PATH)
|
NO_DEFAULT_PATH)
|
||||||
endfunction(find_mlir_lib)
|
endfunction(find_mlir_lib)
|
||||||
|
|
||||||
find_mlir_lib(MLIRAffineOps)
|
find_mlir_lib(MLIRAffineOps)
|
||||||
|
@ -70,6 +70,10 @@ find_mlir_lib(MLIRTransforms)
|
||||||
find_mlir_lib(MLIRTransformUtils)
|
find_mlir_lib(MLIRTransformUtils)
|
||||||
find_mlir_lib(MLIRSupport)
|
find_mlir_lib(MLIRSupport)
|
||||||
find_mlir_lib(MLIROptMain)
|
find_mlir_lib(MLIROptMain)
|
||||||
|
find_mlir_lib(MLIRTargetLLVMIRModuleTranslation)
|
||||||
|
find_mlir_lib(MLIRTargetLLVMIR)
|
||||||
|
find_mlir_lib(MLIRTransformUtils)
|
||||||
|
find_mlir_lib(MLIRTranslation)
|
||||||
find_mlir_lib(MLIRVectorOps)
|
find_mlir_lib(MLIRVectorOps)
|
||||||
|
|
||||||
find_mlir_lib(LLVMCore)
|
find_mlir_lib(LLVMCore)
|
||||||
|
@ -80,46 +84,52 @@ find_mlir_lib(LLVMRemarks)
|
||||||
find_mlir_lib(LLVMIRReader)
|
find_mlir_lib(LLVMIRReader)
|
||||||
find_mlir_lib(LLVMTransformUtils)
|
find_mlir_lib(LLVMTransformUtils)
|
||||||
find_mlir_lib(LLVMBitstreamReader)
|
find_mlir_lib(LLVMBitstreamReader)
|
||||||
|
find_mlir_lib(LLVMAnalysis)
|
||||||
|
find_mlir_lib(LLVMBitWriter)
|
||||||
|
find_mlir_lib(LLVMBitReader)
|
||||||
|
find_mlir_lib(LLVMMC)
|
||||||
|
find_mlir_lib(LLVMMCParser)
|
||||||
|
find_mlir_lib(LLVMObject)
|
||||||
|
find_mlir_lib(LLVMProfileData)
|
||||||
|
find_mlir_lib(LLVMDemangle)
|
||||||
|
|
||||||
|
|
||||||
set(MLIRLibsOnce
|
set(MLIRLibsOnce
|
||||||
|
LLVMAnalysis
|
||||||
|
LLVMAsmParser
|
||||||
|
LLVMBinaryFormat
|
||||||
|
LLVMBitReader
|
||||||
|
LLVMBitstreamReader
|
||||||
|
LLVMBitWriter
|
||||||
|
LLVMCore
|
||||||
|
LLVMIRReader
|
||||||
|
LLVMMC
|
||||||
|
LLVMMCParser
|
||||||
|
LLVMObject
|
||||||
|
LLVMRemarks
|
||||||
|
LLVMSupport
|
||||||
|
LLVMTransformUtils
|
||||||
|
LLVMProfileData
|
||||||
|
LLVMDemangle
|
||||||
MLIRAffineOps
|
MLIRAffineOps
|
||||||
MLIRAffineToStandard
|
MLIRAffineToStandard
|
||||||
MLIRAnalysis
|
MLIRAnalysis
|
||||||
MLIRExecutionEngine
|
MLIRExecutionEngine
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRLLVMIR
|
MLIRLLVMIR
|
||||||
|
MLIRLoopOps
|
||||||
MLIRLoopToStandard
|
MLIRLoopToStandard
|
||||||
|
MLIROptMain
|
||||||
MLIRParser
|
MLIRParser
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRStandardOps
|
MLIRStandardOps
|
||||||
MLIRStandardToLLVM
|
MLIRStandardToLLVM
|
||||||
|
MLIRSupport
|
||||||
MLIRTargetLLVMIR
|
MLIRTargetLLVMIR
|
||||||
MLIRTransforms
|
MLIRTargetLLVMIRModuleTranslation
|
||||||
MLIRAffineOps
|
|
||||||
MLIRAffineToStandard
|
|
||||||
MLIRAnalysis
|
|
||||||
MLIRExecutionEngine
|
|
||||||
MLIRIR
|
|
||||||
MLIRLLVMIR
|
|
||||||
MLIRLoopToStandard
|
|
||||||
MLIRParser
|
|
||||||
MLIRPass
|
|
||||||
MLIRStandardOps
|
|
||||||
MLIRStandardToLLVM
|
|
||||||
MLIRTargetLLVMIR
|
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
MLIRTransformUtils
|
MLIRTransformUtils
|
||||||
MLIRLoopOps
|
MLIRTranslation)
|
||||||
MLIRSupport
|
|
||||||
MLIROptMain
|
|
||||||
LLVMCore
|
|
||||||
LLVMSupport
|
|
||||||
LLVMAsmParser
|
|
||||||
LLVMIRReader
|
|
||||||
LLVMTransformUtils
|
|
||||||
LLVMBinaryFormat
|
|
||||||
LLVMRemarks
|
|
||||||
LLVMBitstreamReader)
|
|
||||||
|
|
||||||
set(MLIRLibs
|
set(MLIRLibs
|
||||||
${MLIRLibsOnce}
|
${MLIRLibsOnce}
|
||||||
|
@ -142,7 +152,7 @@ function(whole_archive_link target lib_dir)
|
||||||
set(link_flags "${link_flags} -L${lib_dir} ")
|
set(link_flags "${link_flags} -L${lib_dir} ")
|
||||||
foreach(LIB ${ARGN})
|
foreach(LIB ${ARGN})
|
||||||
string(CONCAT link_flags ${link_flags}
|
string(CONCAT link_flags ${link_flags}
|
||||||
"-Wl,-force_load ${lib_dir}/lib${LIB}.a ")
|
"-Wl,-force_load ${lib_dir}/lib${LIB}.a ")
|
||||||
endforeach(LIB)
|
endforeach(LIB)
|
||||||
elseif(MSVC)
|
elseif(MSVC)
|
||||||
foreach(LIB ${ARGN})
|
foreach(LIB ${ARGN})
|
||||||
|
@ -170,20 +180,20 @@ function(whole_archive_link_onnf target)
|
||||||
endfunction(whole_archive_link_onnf)
|
endfunction(whole_archive_link_onnf)
|
||||||
|
|
||||||
set(LLVM_CMAKE_DIR
|
set(LLVM_CMAKE_DIR
|
||||||
"${LLVM_BUILD}/lib/cmake/llvm"
|
"${LLVM_BUILD}/lib/cmake/llvm"
|
||||||
CACHE PATH "Path to LLVM cmake modules")
|
CACHE PATH "Path to LLVM cmake modules")
|
||||||
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
|
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
|
||||||
include(AddLLVM)
|
include(AddLLVM)
|
||||||
include(TableGen)
|
include(TableGen)
|
||||||
|
|
||||||
function(onnf_tablegen ofn)
|
function(onnf_tablegen ofn)
|
||||||
tablegen(MLIR
|
tablegen(MLIR
|
||||||
${ARGV}
|
${ARGV}
|
||||||
"-I${MLIR_SRC_INCLUDE_PATH}"
|
"-I${MLIR_SRC_INCLUDE_PATH}"
|
||||||
"-I${MLIR_BIN_INCLUDE_PATH}")
|
"-I${MLIR_BIN_INCLUDE_PATH}")
|
||||||
set(TABLEGEN_OUTPUT
|
set(TABLEGEN_OUTPUT
|
||||||
${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
|
${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
|
||||||
PARENT_SCOPE)
|
PARENT_SCOPE)
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
# Import the pre-built mlir TableGen as an imported exetuable. It is required by
|
# Import the pre-built mlir TableGen as an imported exetuable. It is required by
|
||||||
|
@ -191,5 +201,5 @@ endfunction()
|
||||||
# table gen utility itself can be detected and cause re-compilation of .td file.
|
# table gen utility itself can be detected and cause re-compilation of .td file.
|
||||||
add_executable(mlir-tblgen IMPORTED)
|
add_executable(mlir-tblgen IMPORTED)
|
||||||
set_property(TARGET mlir-tblgen
|
set_property(TARGET mlir-tblgen
|
||||||
PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen)
|
PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen)
|
||||||
set(MLIR_TABLEGEN_EXE mlir-tblgen)
|
set(MLIR_TABLEGEN_EXE mlir-tblgen)
|
||||||
|
|
|
@ -2,6 +2,7 @@ add_executable(onnf main.cpp)
|
||||||
|
|
||||||
target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES})
|
target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES})
|
||||||
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
||||||
|
whole_archive_link_onnf(onnf onnf_transform)
|
||||||
|
|
||||||
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
||||||
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
||||||
|
|
|
@ -6,8 +6,8 @@ add_library(
|
||||||
dialect/krnl/krnl_types.hpp
|
dialect/krnl/krnl_types.hpp
|
||||||
dialect/onnx/onnx_ops.cpp
|
dialect/onnx/onnx_ops.cpp
|
||||||
dialect/onnx/onnx_ops.hpp
|
dialect/onnx/onnx_ops.hpp
|
||||||
dialect/krnl/parser_helper.cpp
|
dialect/krnl/krnl_helper.cpp
|
||||||
dialect/krnl/parser_helper.hpp
|
dialect/krnl/krnl_helper.hpp
|
||||||
pass/shape_inference_pass.cpp
|
pass/shape_inference_pass.cpp
|
||||||
pass/shape_inference_interface.hpp
|
pass/shape_inference_interface.hpp
|
||||||
dialect/onnx/onnxop.inc
|
dialect/onnx/onnxop.inc
|
||||||
|
@ -82,4 +82,5 @@ target_include_directories(onnf_lower_frontend
|
||||||
target_link_libraries(onnf_lower_frontend ${MLIRLibs})
|
target_link_libraries(onnf_lower_frontend ${MLIRLibs})
|
||||||
add_dependencies(onnf_lower_frontend gen_krnl_ops)
|
add_dependencies(onnf_lower_frontend gen_krnl_ops)
|
||||||
|
|
||||||
|
add_subdirectory(transform)
|
||||||
add_subdirectory(tool)
|
add_subdirectory(tool)
|
||||||
|
|
|
@ -0,0 +1,139 @@
|
||||||
|
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||||
|
#include "mlir/IR/AffineExpr.h"
|
||||||
|
|
||||||
|
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||||
|
|
||||||
|
#include "krnl_helper.hpp"
|
||||||
|
|
||||||
|
namespace onnf {
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||||
|
const Type& operandType, Value*& operand) {
|
||||||
|
// If operand queue is empty, parse more operands and cache them.
|
||||||
|
if (_operandRefQueue.empty()) {
|
||||||
|
// Parse operand types:
|
||||||
|
llvm::SmallVector<OpAsmParser::OperandType, 2> operand_refs;
|
||||||
|
_parser.parseOperandList(operand_refs);
|
||||||
|
|
||||||
|
// Record operands:
|
||||||
|
for (auto& operand_ref : operand_refs)
|
||||||
|
_operandRefQueue.emplace(operand_ref);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we parsed some operand reference(s), resolve the ref to an operand:
|
||||||
|
if (!_operandRefQueue.empty()) {
|
||||||
|
auto operand_ref = _operandRefQueue.front();
|
||||||
|
_operandRefQueue.pop();
|
||||||
|
|
||||||
|
llvm::SmallVector<Value*, 1> operands;
|
||||||
|
_parser.resolveOperand(operand_ref, operandType, operands);
|
||||||
|
operand = operands.front();
|
||||||
|
return success();
|
||||||
|
} else {
|
||||||
|
operand = nullptr;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||||
|
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
|
||||||
|
Value* operand = nullptr;
|
||||||
|
if (ParseOptionalOperand(operandType, operand))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
operandList.emplace_back(operand);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||||
|
const Type& operandType, Value*& operand) {
|
||||||
|
if (ParseOptionalOperand(operandType, operand))
|
||||||
|
return _parser.emitError(
|
||||||
|
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||||
|
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
|
||||||
|
if (ParseOptionalOperand(operandType, operandList))
|
||||||
|
return _parser.emitError(
|
||||||
|
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
|
||||||
|
unsigned numSymbols, OpAsmPrinter& p) {
|
||||||
|
p << '(';
|
||||||
|
p.printOperands(begin, begin + numDims);
|
||||||
|
p << ')';
|
||||||
|
|
||||||
|
if (numSymbols) {
|
||||||
|
p << '[';
|
||||||
|
p.printOperands(begin + numDims, begin + numDims + numSymbols);
|
||||||
|
p << ']';
|
||||||
|
}
|
||||||
|
|
||||||
|
begin = std::next(begin, numDims + numSymbols);
|
||||||
|
}
|
||||||
|
|
||||||
|
void printBound(AffineMapAttr boundMap,
|
||||||
|
Operation::operand_iterator& boundOperandsBeg, const char* prefix,
|
||||||
|
OpAsmPrinter& p) {
|
||||||
|
AffineMap map = boundMap.getValue();
|
||||||
|
|
||||||
|
// Check if this bound should be printed using custom assembly form.
|
||||||
|
// The decision to restrict printing custom assembly form to trivial cases
|
||||||
|
// comes from the will to roundtrip MLIR binary -> text -> binary in a
|
||||||
|
// lossless way.
|
||||||
|
// Therefore, custom assembly form parsing and printing is only supported for
|
||||||
|
// zero-operand constant maps and single symbol operand identity maps.
|
||||||
|
if (map.getNumResults() == 1) {
|
||||||
|
AffineExpr expr = map.getResult(0);
|
||||||
|
|
||||||
|
// Print constant bound.
|
||||||
|
if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
|
||||||
|
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
|
||||||
|
p << constExpr.getValue();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print bound that consists of a single SSA symbol if the map is over a
|
||||||
|
// single symbol.
|
||||||
|
if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
|
||||||
|
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
|
||||||
|
p.printOperand(*(boundOperandsBeg++));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Map has multiple results. Print 'min' or 'max' prefix.
|
||||||
|
p << prefix << ' ';
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the map and its operands.
|
||||||
|
p << boundMap;
|
||||||
|
printDimAndSymbolList(
|
||||||
|
boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p);
|
||||||
|
}
|
||||||
|
} // namespace onnf
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
|
||||||
|
if (boundMaps.size() % 2 == 0)
|
||||||
|
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
|
||||||
|
AffineMap map = builder.getConstantAffineMap(bound);
|
||||||
|
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||||
|
}
|
||||||
|
|
||||||
|
void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) {
|
||||||
|
if (boundMaps.size() % 2 == 0)
|
||||||
|
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
|
||||||
|
AffineMap map = builder.getSymbolIdentityMap();
|
||||||
|
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||||
|
_operands.emplace_back(operand);
|
||||||
|
}
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,102 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
namespace onnf {
|
||||||
|
|
||||||
|
class KrnlDialectOperandParser {
|
||||||
|
public:
|
||||||
|
explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser)
|
||||||
|
: _parser(parser), _builder(parser.getBuilder()){};
|
||||||
|
|
||||||
|
// Parse an optional operand.
|
||||||
|
mlir::ParseResult ParseOptionalOperand(
|
||||||
|
const mlir::Type& operandType, mlir::Value*& operand);
|
||||||
|
|
||||||
|
// Parse an optional operand and push it to an operand list.
|
||||||
|
mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType,
|
||||||
|
llvm::SmallVectorImpl<mlir::Value*>& operandList);
|
||||||
|
|
||||||
|
// Parse a required operand.
|
||||||
|
mlir::ParseResult ParseOperand(
|
||||||
|
const mlir::Type& operandType, mlir::Value*& operand);
|
||||||
|
|
||||||
|
// Parse a required operand and push it to an operand list.
|
||||||
|
mlir::ParseResult ParseOperand(const mlir::Type& operandType,
|
||||||
|
llvm::SmallVectorImpl<mlir::Value*>& operandList);
|
||||||
|
|
||||||
|
// Do we have more operands to parse?
|
||||||
|
bool hasOperandLeft() { return !_operandRefQueue.empty(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
mlir::OpAsmParser& _parser;
|
||||||
|
|
||||||
|
mlir::Builder& _builder;
|
||||||
|
|
||||||
|
// A queue storing the parsed SSA id references.
|
||||||
|
std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Adapted from:
|
||||||
|
// https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
|
||||||
|
// Main difference is that it advances the iterator `begin` as it consumes
|
||||||
|
// dimension and symbol operands.
|
||||||
|
void printDimAndSymbolList(mlir::Operation::operand_iterator& begin,
|
||||||
|
unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p);
|
||||||
|
|
||||||
|
// Adapted from:
|
||||||
|
// https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
|
||||||
|
// Main difference is that it advances the iterator `boundOperandsBeg` as it
|
||||||
|
// prints bound.
|
||||||
|
void printBound(mlir::AffineMapAttr boundMap,
|
||||||
|
mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix,
|
||||||
|
mlir::OpAsmPrinter& p);
|
||||||
|
} // namespace onnf
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
struct KrnlIterateOperandPack {
|
||||||
|
KrnlIterateOperandPack(mlir::Builder& builder,
|
||||||
|
llvm::ArrayRef<mlir::Value*> inputLoops,
|
||||||
|
llvm::ArrayRef<mlir::Value*> optimizedLoops)
|
||||||
|
: builder(builder),
|
||||||
|
inputLoops(inputLoops),
|
||||||
|
optimizedLoops(optimizedLoops) {
|
||||||
|
_operands.insert(
|
||||||
|
_operands.end(), optimizedLoops.begin(), optimizedLoops.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
void pushConstantBound(int64_t bound);
|
||||||
|
|
||||||
|
void pushOperandBound(mlir::Value* operand);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value*, 8> getOperands() const { return _operands; }
|
||||||
|
|
||||||
|
mlir::ArrayAttr getAttributes() const {
|
||||||
|
return builder.getArrayAttr(boundMaps);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getNumOptimizedLoops() const { return optimizedLoops.size(); }
|
||||||
|
|
||||||
|
size_t getNumInputLoops() const { return inputLoops.size(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int _boundIdx = 0;
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value*, 8> _operands;
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Attribute, 8> boundMaps;
|
||||||
|
|
||||||
|
llvm::ArrayRef<mlir::Value*> inputLoops, optimizedLoops;
|
||||||
|
|
||||||
|
mlir::Builder& builder;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlir
|
|
@ -9,10 +9,10 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
|
||||||
#include "src/compiler/dialect/krnl/parser_helper.hpp"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
#include "mlir/IR/Block.h"
|
#include "mlir/IR/Block.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
|
@ -24,6 +24,8 @@
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
|
||||||
|
|
||||||
#include "krnl_ops.hpp"
|
#include "krnl_ops.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
@ -52,26 +54,25 @@ void KrnlDefineLoopsOp::build(
|
||||||
}
|
}
|
||||||
|
|
||||||
void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) {
|
void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) {
|
||||||
auto num_loop_attr = op.getAttrOfType<IntegerAttr>(op.getNumLoopsAttrName());
|
auto numLoopAttr =
|
||||||
p << "krnl.define_loops " << num_loop_attr.getValue().getSExtValue();
|
op.getAttrOfType<IntegerAttr>(KrnlDefineLoopsOp::getNumLoopsAttrName());
|
||||||
|
p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseKrnlDefineLoopsOp(
|
ParseResult parseKrnlDefineLoopsOp(
|
||||||
OpAsmParser& parser, OperationState& result) {
|
OpAsmParser& parser, OperationState& result) {
|
||||||
// Parse the attribute indicating number of loops defined.
|
// Parse the attribute indicating number of loops defined.
|
||||||
IntegerAttr num_loops;
|
IntegerAttr numLoops;
|
||||||
auto& builder = parser.getBuilder();
|
auto& builder = parser.getBuilder();
|
||||||
auto int32_type = builder.getIntegerType(64);
|
auto intType = builder.getIntegerType(64);
|
||||||
if (parser.parseAttribute(num_loops, int32_type,
|
if (parser.parseAttribute(numLoops, intType,
|
||||||
KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
|
KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto loop_types = llvm::SmallVector<Type, 4>(
|
auto loopTypes = llvm::SmallVector<Type, 4>(
|
||||||
num_loops.getValue().getSExtValue(), LoopType::get(builder.getContext()));
|
numLoops.getValue().getSExtValue(), LoopType::get(builder.getContext()));
|
||||||
if (parser.addTypesToList(loop_types, result.types))
|
if (parser.addTypesToList(loopTypes, result.types))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -142,39 +143,14 @@ ParseResult parseKrnlOptimizeLoopsOp(
|
||||||
* %i0 = 10 to N : %i1 = M to 20
|
* %i0 = 10 to N : %i1 = M to 20
|
||||||
*/
|
*/
|
||||||
void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
||||||
ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops,
|
KrnlIterateOperandPack operandPack) {
|
||||||
ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds,
|
|
||||||
ArrayRef<int> bound_types) {
|
|
||||||
// Record optimized loops and the number of such loops.
|
// Record optimized loops and the number of such loops.
|
||||||
result.addOperands(optimized_loops);
|
result.addOperands(operandPack.getOperands());
|
||||||
|
result.addAttribute(
|
||||||
|
KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes());
|
||||||
|
|
||||||
result.addAttribute(getNumOptimizedLoopsAttrName(),
|
result.addAttribute(getNumOptimizedLoopsAttrName(),
|
||||||
builder->getI64IntegerAttr(optimized_loops.size()));
|
builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops()));
|
||||||
|
|
||||||
// Record input loops and the number of such loops.
|
|
||||||
result.addOperands(input_loops);
|
|
||||||
result.addAttribute(getNumInputLoopsAttrName(),
|
|
||||||
builder->getI64IntegerAttr(input_loops.size()));
|
|
||||||
|
|
||||||
// Record bound either as attribute or from operand list.
|
|
||||||
auto next_operand_bound = operand_bounds.begin();
|
|
||||||
auto next_const_bound = const_bounds.begin();
|
|
||||||
for (size_t i = 0; i < bound_types.size(); i++) {
|
|
||||||
auto bound_type = bound_types[i];
|
|
||||||
if (bound_type == 0) {
|
|
||||||
// Constant bound.
|
|
||||||
result.addAttribute(getBoundAttrName(i / 2, i % 2),
|
|
||||||
builder->getI64IntegerAttr(*next_const_bound));
|
|
||||||
next_const_bound = std::next(next_const_bound);
|
|
||||||
} else {
|
|
||||||
// Operand bound.
|
|
||||||
result.addOperands(*next_operand_bound);
|
|
||||||
next_operand_bound = std::next(next_operand_bound);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Record bound types as attribute:
|
|
||||||
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
|
|
||||||
builder->getI32ArrayAttr(bound_types));
|
|
||||||
|
|
||||||
// Create a region and a block for the body. The arguments of the region are
|
// Create a region and a block for the body. The arguments of the region are
|
||||||
// the loop induction variables; there can be multiple induction variables
|
// the loop induction variables; there can be multiple induction variables
|
||||||
|
@ -182,7 +158,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
||||||
Region* bodyRegion = result.addRegion();
|
Region* bodyRegion = result.addRegion();
|
||||||
auto* body = new Block();
|
auto* body = new Block();
|
||||||
auto body_args = llvm::SmallVector<Type, 4>(
|
auto body_args = llvm::SmallVector<Type, 4>(
|
||||||
input_loops.size(), IndexType::get(builder->getContext()));
|
operandPack.getNumInputLoops(), IndexType::get(builder->getContext()));
|
||||||
body->addArguments(body_args);
|
body->addArguments(body_args);
|
||||||
bodyRegion->push_back(body);
|
bodyRegion->push_back(body);
|
||||||
|
|
||||||
|
@ -192,57 +168,31 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
||||||
void print(OpAsmPrinter& p, KrnlIterateOp& op) {
|
void print(OpAsmPrinter& p, KrnlIterateOp& op) {
|
||||||
p << "krnl.iterate(";
|
p << "krnl.iterate(";
|
||||||
// Print optimized loops:
|
// Print optimized loops:
|
||||||
auto num_optimized_loops = op.getNumOptimizedLoops();
|
auto numOptimizedLoops = op.getNumOptimizedLoops();
|
||||||
p.printOperands(op.operand_begin(), op.operand_begin() + num_optimized_loops);
|
p.printOperands(op.operand_begin(), op.operand_begin() + numOptimizedLoops);
|
||||||
p << ") with (";
|
p << ") with (";
|
||||||
|
|
||||||
// Set up iterator to input loops:
|
auto inductionVars = op.bodyRegion().begin()->getArguments();
|
||||||
auto num_input_loops = op.getNumInputLoops();
|
auto boundItr =
|
||||||
auto input_loop_begin = op.operand_begin() + num_optimized_loops;
|
op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
|
||||||
|
.getValue()
|
||||||
|
.begin();
|
||||||
|
auto operandItr = op.operand_begin() + numOptimizedLoops;
|
||||||
|
|
||||||
// Set up iterators to operand bounds.
|
std::string delimiter;
|
||||||
auto next_operand_bound = input_loop_begin + num_input_loops;
|
for (auto& var : inductionVars) {
|
||||||
|
p << delimiter;
|
||||||
// Function to print a lower or upper bound.
|
p.printOperand(*operandItr++);
|
||||||
auto print_bound = [&](ArrayRef<Attribute> bound_types, size_t idx) {
|
|
||||||
IntegerAttr type = bound_types[idx].dyn_cast<IntegerAttr>();
|
|
||||||
if (type.getValue().getSExtValue() == 0) {
|
|
||||||
// Bound is an integer attribute.
|
|
||||||
auto bound_idx = idx / 2;
|
|
||||||
auto is_ub = idx % 2;
|
|
||||||
IntegerAttr bound = op.getAttrOfType<IntegerAttr>(
|
|
||||||
KrnlIterateOp::getBoundAttrName(bound_idx, is_ub));
|
|
||||||
p << bound.getValue().getSExtValue();
|
|
||||||
} else {
|
|
||||||
// Bound is an operand.
|
|
||||||
p.printOperand(*next_operand_bound);
|
|
||||||
next_operand_bound = std::next(next_operand_bound);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
auto induction_variables = op.bodyRegion().front().getArguments();
|
|
||||||
ArrayRef<Attribute> bound_types =
|
|
||||||
op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundTypesAttrName())
|
|
||||||
.getValue();
|
|
||||||
|
|
||||||
// Print input loop operands, induction variables and their ranges.
|
|
||||||
for (size_t i = 0; i < num_input_loops; i++) {
|
|
||||||
if (i != 0)
|
|
||||||
p << ", ";
|
|
||||||
|
|
||||||
p.printOperand(*std::next(input_loop_begin, i));
|
|
||||||
p << " -> ";
|
p << " -> ";
|
||||||
|
p.printOperand(var);
|
||||||
// Print induction variable block argument.
|
|
||||||
p.printOperand(induction_variables[i]);
|
|
||||||
p << " = ";
|
p << " = ";
|
||||||
|
onnf::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "max", p);
|
||||||
print_bound(bound_types, 2 * i); // Print lower bound.
|
|
||||||
p << " to ";
|
p << " to ";
|
||||||
print_bound(bound_types, 2 * i + 1); // Print upper bound.
|
onnf::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "min", p);
|
||||||
|
delimiter = ", ";
|
||||||
}
|
}
|
||||||
p << ")";
|
|
||||||
|
|
||||||
|
p << ")";
|
||||||
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
|
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
|
||||||
/*printBlockTerminators=*/false);
|
/*printBlockTerminators=*/false);
|
||||||
}
|
}
|
||||||
|
@ -250,80 +200,109 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
|
||||||
ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
||||||
auto builder = parser.getBuilder();
|
auto builder = parser.getBuilder();
|
||||||
auto context = builder.getContext();
|
auto context = builder.getContext();
|
||||||
onnf::KrnlDialectOperandParser operand_parser(parser);
|
onnf::KrnlDialectOperandParser operandParser(parser);
|
||||||
|
|
||||||
// Parse optimized loops:
|
// Parse optimized loops:
|
||||||
SmallVector<OpAsmParser::OperandType, 4> num_optimized_loops;
|
SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs;
|
||||||
if (parser.parseOperandList(
|
if (parser.parseOperandList(
|
||||||
num_optimized_loops, OpAsmParser::Delimiter::Paren) ||
|
optimizedLoopRefs, OpAsmParser::Delimiter::Paren) ||
|
||||||
parser.resolveOperands(num_optimized_loops,
|
parser.resolveOperands(optimizedLoopRefs,
|
||||||
LoopType::get(result.getContext()), result.operands))
|
LoopType::get(result.getContext()), result.operands))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Record how many optimized loops did we parse.
|
// Record how many optimized loops did we parse.
|
||||||
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
|
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
|
||||||
builder.getI64IntegerAttr(num_optimized_loops.size()));
|
builder.getI64IntegerAttr(optimizedLoopRefs.size()));
|
||||||
|
|
||||||
// Parse input loops and their lower and upper bounds.
|
// Parse input loops and their lower and upper bounds.
|
||||||
SmallVector<OpAsmParser::OperandType, 4> in_loop_refs, induction_var_refs;
|
SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs;
|
||||||
SmallVector<Value*, 4> in_loop_operands, operand_bounds;
|
SmallVector<Attribute, 4> boundMaps;
|
||||||
SmallVector<Attribute, 4> bound_types;
|
|
||||||
SmallVector<IntegerAttr, 4> const_bounds;
|
|
||||||
|
|
||||||
if (parser.parseKeyword("with") || parser.parseLParen())
|
if (parser.parseKeyword("with") || parser.parseLParen())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// A function to parse a lower or upper bound.
|
// A function to parse a lower or upper bound.
|
||||||
auto parse_bound = [&result, &builder, &operand_parser, &parser, &bound_types,
|
auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps](
|
||||||
&operand_bounds, &const_bounds](
|
bool isUpper) -> ParseResult {
|
||||||
bool is_ub, size_t bound_pair_count) -> ParseResult {
|
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if
|
||||||
|
// the map has multiple results.
|
||||||
|
bool failedToParsedMinMax =
|
||||||
|
failed(parser.parseOptionalKeyword(isUpper ? "min" : "max"));
|
||||||
|
|
||||||
// Try parse an SSA operand.
|
// Try parse an SSA operand.
|
||||||
Value* bound;
|
if (succeeded(operandParser.ParseOptionalOperand(
|
||||||
operand_parser.ParseOptionalOperand(builder.getIndexType(), bound);
|
builder.getIndexType(), result.operands))) {
|
||||||
|
AffineMap map = builder.getSymbolIdentityMap();
|
||||||
|
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
if (bound != nullptr) {
|
// Bound is not an SSA id, then it must be an integer.
|
||||||
// Parsed an SSA id as bound.
|
// Parse an integer constant attribute.
|
||||||
operand_bounds.emplace_back(bound);
|
// Get the attribute location.
|
||||||
// Record bound_type as an operand type.
|
llvm::SMLoc attrLoc = parser.getCurrentLocation();
|
||||||
bound_types.emplace_back(builder.getI32IntegerAttr(0));
|
Attribute boundAttr;
|
||||||
} else {
|
llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer;
|
||||||
// Bound is not an SSA id, then it must be an integer.
|
if (parser.parseAttribute(
|
||||||
// Parse an integer constant attribute.
|
boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer))
|
||||||
IntegerAttr boundAttr;
|
return failure();
|
||||||
if (parser.parseAttribute(boundAttr, builder.getIndexType(),
|
|
||||||
KrnlIterateOp::getBoundAttrName(bound_pair_count, is_ub),
|
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
|
||||||
result.attributes))
|
unsigned currentNumOperands = result.operands.size();
|
||||||
|
unsigned numDims = 0;
|
||||||
|
if (parseDimAndSymbolList(parser, result.operands, numDims))
|
||||||
return failure();
|
return failure();
|
||||||
const_bounds.emplace_back(
|
|
||||||
builder.getIntegerAttr(builder.getIndexType(), boundAttr.getValue()));
|
|
||||||
|
|
||||||
// Record that the bound_type is a constant integer attribute.
|
auto map = affineMapAttr.getValue();
|
||||||
bound_types.emplace_back(builder.getI32IntegerAttr(1));
|
if (map.getNumDims() != numDims)
|
||||||
|
return parser.emitError(parser.getNameLoc(),
|
||||||
|
"dim operand count and integer set dim count must match");
|
||||||
|
|
||||||
|
unsigned numDimAndSymbolOperands =
|
||||||
|
result.operands.size() - currentNumOperands;
|
||||||
|
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
|
||||||
|
return parser.emitError(parser.getNameLoc(),
|
||||||
|
"symbol operand count and integer set symbol count must match");
|
||||||
|
|
||||||
|
// If the map has multiple results, make sure that we parsed the min/max
|
||||||
|
// prefix.
|
||||||
|
if (map.getNumResults() > 1 && failedToParsedMinMax) {
|
||||||
|
if (isUpper)
|
||||||
|
return parser.emitError(attrLoc,
|
||||||
|
"upper loop bound affine map with multiple "
|
||||||
|
"results requires 'min' prefix");
|
||||||
|
return parser.emitError(attrLoc,
|
||||||
|
"lower loop bound affine mapwith "
|
||||||
|
"multiple results requires 'max' prefix");
|
||||||
|
}
|
||||||
|
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
|
||||||
|
AffineMap map =
|
||||||
|
builder.getConstantAffineMap(integerAttr.getValue().getSExtValue());
|
||||||
|
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool keep_parsing; // Do we keep parsing loops/bounds?
|
bool keepParsing; // Do we keep parsing loops/bounds?
|
||||||
size_t bound_pair_count = 0; // Record the number of bound pairs parsed.
|
|
||||||
do {
|
do {
|
||||||
// Parse an input loop operand;
|
// Parse an input loop operand;
|
||||||
Value* in_loop_operand;
|
operandParser.ParseOperand(LoopType::get(context), result.operands);
|
||||||
operand_parser.ParseOperand(LoopType::get(context), in_loop_operand);
|
|
||||||
in_loop_operands.emplace_back(in_loop_operand);
|
|
||||||
|
|
||||||
parser.parseArrow();
|
parser.parseArrow();
|
||||||
|
|
||||||
// Parse induction variable.
|
// Parse induction variable.
|
||||||
OpAsmParser::OperandType induction_var;
|
OpAsmParser::OperandType inductionVar;
|
||||||
if (parser.parseRegionArgument(induction_var) || parser.parseEqual())
|
if (parser.parseRegionArgument(inductionVar) || parser.parseEqual())
|
||||||
return failure();
|
return failure();
|
||||||
induction_var_refs.emplace_back(induction_var);
|
inductionVarRefs.emplace_back(inductionVar);
|
||||||
|
|
||||||
// Parse bound par (min to max).
|
// Parse bound par (min to max).
|
||||||
if (parse_bound(false, bound_pair_count) || parser.parseKeyword("to") ||
|
if (parseBound(/*isUpper=*/false) || parser.parseKeyword("to") ||
|
||||||
parse_bound(true, bound_pair_count))
|
parseBound(/*isUpper=*/true))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
bound_pair_count++;
|
|
||||||
// We may fail to parse a comma if an operand bound is followed by
|
// We may fail to parse a comma if an operand bound is followed by
|
||||||
// a comma and the next input loop operand, in which case
|
// a comma and the next input loop operand, in which case
|
||||||
// the entire "{operand bound}, {input_loop_operand}" sequence will
|
// the entire "{operand bound}, {input_loop_operand}" sequence will
|
||||||
|
@ -331,33 +310,19 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
||||||
parser.parseOptionalComma();
|
parser.parseOptionalComma();
|
||||||
|
|
||||||
// If we don't see a RParen token, we keep parsing.
|
// If we don't see a RParen token, we keep parsing.
|
||||||
keep_parsing = failed(parser.parseOptionalRParen());
|
keepParsing = failed(parser.parseOptionalRParen());
|
||||||
} while (keep_parsing);
|
} while (keepParsing);
|
||||||
|
|
||||||
// At this point, there shouldn't be any operands left to parse.
|
// At this point, there shouldn't be any operands left to parse.
|
||||||
if (operand_parser.has_operand_left())
|
if (operandParser.hasOperandLeft())
|
||||||
return parser.emitError(parser.getCurrentLocation());
|
return parser.emitError(parser.getCurrentLocation());
|
||||||
|
result.addAttribute(
|
||||||
|
KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps));
|
||||||
|
|
||||||
// Record how many input loops did we parse.
|
|
||||||
result.addOperands(in_loop_operands);
|
|
||||||
result.addAttribute(KrnlIterateOp::getNumInputLoopsAttrName(),
|
|
||||||
builder.getI64IntegerAttr(in_loop_operands.size()));
|
|
||||||
|
|
||||||
// Add operand bounds to the list of operands of current operation.
|
|
||||||
result.addOperands(operand_bounds);
|
|
||||||
|
|
||||||
// A list of 2N elements where the (2n) and (2n+1) th element specifies
|
|
||||||
// whether the lower and upper bound of the n'th induction variable is stored
|
|
||||||
// as an operand or as an attribute. N being the number of input loops
|
|
||||||
// specified in this krnl.iterate operation.
|
|
||||||
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
|
|
||||||
builder.getArrayAttr(bound_types));
|
|
||||||
|
|
||||||
// Parse the schedule body region.
|
|
||||||
Region* region = result.addRegion();
|
Region* region = result.addRegion();
|
||||||
SmallVector<Type, 4> induction_var_types(
|
SmallVector<Type, 4> inductionVarTypes(
|
||||||
induction_var_refs.size(), builder.getIndexType());
|
inductionVarRefs.size(), builder.getIndexType());
|
||||||
if (parser.parseRegion(*region, induction_var_refs, induction_var_types))
|
if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Ensure iterate region is closed off with krnl.terminate.
|
// Ensure iterate region is closed off with krnl.terminate.
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
|
||||||
#include "src/compiler/dialect/krnl/krnl_types.hpp"
|
#include "src/compiler/dialect/krnl/krnl_types.hpp"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
|
||||||
def Krnl_Dialect : Dialect {
|
def Krnl_Dialect : Dialect {
|
||||||
let name = "krnl";
|
let name = "krnl";
|
||||||
let cppNamespace = "";
|
let cppNamespace = "";
|
||||||
|
@ -119,17 +120,14 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
|
||||||
|
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
OpBuilder<"Builder *builder, OperationState &result, "
|
||||||
"ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops, "
|
"KrnlIterateOperandPack operandPack">
|
||||||
"ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds, "
|
|
||||||
"ArrayRef<int> bound_types">
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
|
// In krnl.iterate operation, operands are stored as such
|
||||||
// In krnl.iterate operation, three types of SSA values are stored:
|
|
||||||
// - Optimized krnl.loops.
|
// - Optimized krnl.loops.
|
||||||
// - Input krnl.loops.
|
// - Input krnl.loops and their operand bounds. (TODO(Tian) explain better how we store them).
|
||||||
// - SSA value based induction variable bound (parametric bound).
|
|
||||||
// We record the number of optimized and input loops to separate these three
|
// We record the number of optimized and input loops to separate these three
|
||||||
// group of operands out.
|
// group of operands out.
|
||||||
static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; }
|
static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; }
|
||||||
|
@ -143,32 +141,8 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
|
||||||
return num_optimized_loops;
|
return num_optimized_loops;
|
||||||
}
|
}
|
||||||
|
|
||||||
static StringRef getNumInputLoopsAttrName() { return "num_input_loops"; }
|
// Get name of the attribute for storing bound represented using affine maps.
|
||||||
|
static StringRef getBoundsAttrName() { return "bounds"; }
|
||||||
int64_t getNumInputLoops() {
|
|
||||||
auto num_loops =
|
|
||||||
getAttrOfType<IntegerAttr>(
|
|
||||||
getNumInputLoopsAttrName())
|
|
||||||
.getValue()
|
|
||||||
.getSExtValue();
|
|
||||||
return num_loops;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Constant bounds are stored here as a list attribute.
|
|
||||||
static StringRef getConstantBoundsAttrName() { return "constant_bounds"; }
|
|
||||||
|
|
||||||
// Store type of each bound as three types:
|
|
||||||
// - 0 = constant attribute.
|
|
||||||
// - 1 = operand type.
|
|
||||||
// - 2 = affine maps (TODO).
|
|
||||||
static StringRef getBoundTypesAttrName() { return "bound_types"; }
|
|
||||||
|
|
||||||
// Get dynamic attribute name for the i-th lower and upper bound.
|
|
||||||
static std::string getBoundAttrName(int64_t i, bool is_ub) {
|
|
||||||
std::string bound_type = is_ub ? "_ub" : "_lb";
|
|
||||||
std::string bound_idx = std::to_string(i);
|
|
||||||
return "__bound_" + bound_idx + bound_type;
|
|
||||||
}
|
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let printer = [{ return ::print(p, *this); }];
|
let printer = [{ return ::print(p, *this); }];
|
||||||
|
|
|
@ -1,52 +0,0 @@
|
||||||
//===------------------ parser_helper.cpp - MLIR Operations ---------------===//
|
|
||||||
//
|
|
||||||
// Copyright 2019 The IBM Research Authors.
|
|
||||||
//
|
|
||||||
// =============================================================================
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include "parser_helper.hpp"
|
|
||||||
|
|
||||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
|
||||||
|
|
||||||
namespace onnf {
|
|
||||||
|
|
||||||
mlir::ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
|
||||||
mlir::Type operand_type, mlir::Value*& operand) {
|
|
||||||
// If operand queue is empty, parse more operands and cache them.
|
|
||||||
if (_operand_ref_queue.empty()) {
|
|
||||||
// Parse operand types:
|
|
||||||
llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> operand_refs;
|
|
||||||
_parser.parseOperandList(operand_refs);
|
|
||||||
|
|
||||||
// Record operands:
|
|
||||||
for (auto& operand_ref : operand_refs)
|
|
||||||
_operand_ref_queue.emplace(operand_ref);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we parsed some operand reference(s), resolve the ref to an operand:
|
|
||||||
if (!_operand_ref_queue.empty()) {
|
|
||||||
auto operand_ref = _operand_ref_queue.front();
|
|
||||||
_operand_ref_queue.pop();
|
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value*, 1> operands;
|
|
||||||
_parser.resolveOperand(operand_ref, operand_type, operands);
|
|
||||||
operand = operands.front();
|
|
||||||
return mlir::success();
|
|
||||||
} else {
|
|
||||||
operand = nullptr;
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::ParseResult KrnlDialectOperandParser::ParseOperand(
|
|
||||||
mlir::Type operand_type, mlir::Value*& operand) {
|
|
||||||
ParseOptionalOperand(operand_type, operand);
|
|
||||||
if (operand == nullptr)
|
|
||||||
return _parser.emitError(
|
|
||||||
_parser.getCurrentLocation(), "Expecting an operand.");
|
|
||||||
return mlir::success();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnf
|
|
|
@ -1,46 +0,0 @@
|
||||||
//===------------------ parser_helper.hpp - MLIR Operations ---------------===//
|
|
||||||
//
|
|
||||||
// Copyright 2019 The IBM Research Authors.
|
|
||||||
//
|
|
||||||
// =============================================================================
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <queue>
|
|
||||||
|
|
||||||
#include "mlir/IR/Builders.h"
|
|
||||||
#include "mlir/IR/Dialect.h"
|
|
||||||
#include "mlir/IR/OpDefinition.h"
|
|
||||||
#include "mlir/IR/OpImplementation.h"
|
|
||||||
#include "mlir/IR/StandardTypes.h"
|
|
||||||
|
|
||||||
namespace onnf {
|
|
||||||
|
|
||||||
class KrnlDialectOperandParser {
|
|
||||||
public:
|
|
||||||
KrnlDialectOperandParser(mlir::OpAsmParser& parser)
|
|
||||||
: _parser(parser), _builder(parser.getBuilder()){};
|
|
||||||
|
|
||||||
// Parse an optional operand.
|
|
||||||
mlir::ParseResult ParseOptionalOperand(
|
|
||||||
mlir::Type operand_type, mlir::Value*& operand);
|
|
||||||
|
|
||||||
// Parse a required operand.
|
|
||||||
mlir::ParseResult ParseOperand(
|
|
||||||
mlir::Type operand_type, mlir::Value*& operand);
|
|
||||||
|
|
||||||
// Do we have more operands to parse?
|
|
||||||
bool has_operand_left() { return !_operand_ref_queue.empty(); }
|
|
||||||
|
|
||||||
private:
|
|
||||||
mlir::OpAsmParser& _parser;
|
|
||||||
|
|
||||||
mlir::Builder& _builder;
|
|
||||||
|
|
||||||
// A queue storing the parsed SSA id references.
|
|
||||||
std::queue<mlir::OpAsmParser::OperandType> _operand_ref_queue;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace onnf
|
|
|
@ -16,6 +16,7 @@
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
|
||||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||||
|
|
||||||
|
@ -43,13 +44,12 @@ static MemRefType convertTensorToMemRef(TensorType type) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Insert an allocation and deallocation for the given MemRefType.
|
/// Insert an allocation and deallocation for the given MemRefType.
|
||||||
static Value* insertAllocAndDealloc(
|
static Value* insertAllocAndDealloc(MemRefType type, Location loc,
|
||||||
MemRefType type, Location loc, PatternRewriter& rewriter,
|
PatternRewriter& rewriter, Value* oldMemRef = nullptr) {
|
||||||
Value *oldMemRef = nullptr) {
|
|
||||||
// Put together alloc operands for any dynamic dimensions of the memref.
|
// Put together alloc operands for any dynamic dimensions of the memref.
|
||||||
AllocOp alloc;
|
AllocOp alloc;
|
||||||
if (oldMemRef) {
|
if (oldMemRef) {
|
||||||
SmallVector<Value *, 4> allocOperands;
|
SmallVector<Value*, 4> allocOperands;
|
||||||
auto memRefShape = type.getShape();
|
auto memRefShape = type.getShape();
|
||||||
for (int i = 0; i < memRefShape.size(); ++i)
|
for (int i = 0; i < memRefShape.size(); ++i)
|
||||||
if (memRefShape[i] < 0)
|
if (memRefShape[i] < 0)
|
||||||
|
@ -95,7 +95,7 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
||||||
// dimensions with the result at this pre-optimization phase.
|
// dimensions with the result at this pre-optimization phase.
|
||||||
// TODO: verify that dimensions match.
|
// TODO: verify that dimensions match.
|
||||||
// TODO: can the dimension of the result differ after optimizations?
|
// TODO: can the dimension of the result differ after optimizations?
|
||||||
Value *alloc;
|
Value* alloc;
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||||
else
|
else
|
||||||
|
@ -122,33 +122,22 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
||||||
}
|
}
|
||||||
Block& optimizationBlock = optimizedLoopsOp.region().front();
|
Block& optimizationBlock = optimizedLoopsOp.region().front();
|
||||||
|
|
||||||
|
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
||||||
// Iterate over the loop nest.
|
// Iterate over the loop nest.
|
||||||
// TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
|
// TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
|
||||||
// to KrnlIterateOp instead.
|
// to KrnlIterateOp instead.
|
||||||
SmallVector<Value*, 8> operandBounds;
|
|
||||||
SmallVector<int64_t, 8> constBounds;
|
|
||||||
SmallVector<int, 16> boundTypes;
|
|
||||||
for (int i = 0; i < rank; ++i) {
|
for (int i = 0; i < rank; ++i) {
|
||||||
if (memRefShape[i] < 0) {
|
if (memRefShape[i] < 0) {
|
||||||
// This is a dynamic value, hence use operands.
|
pack.pushConstantBound(0);
|
||||||
// Lower bound
|
pack.pushOperandBound(
|
||||||
constBounds.push_back(0);
|
|
||||||
boundTypes.push_back(0);
|
|
||||||
// Upper bound
|
|
||||||
operandBounds.push_back(
|
|
||||||
rewriter.create<DimOp>(loc, operands[0], i).getResult());
|
rewriter.create<DimOp>(loc, operands[0], i).getResult());
|
||||||
boundTypes.push_back(1);
|
|
||||||
} else {
|
} else {
|
||||||
// Lower bound
|
pack.pushConstantBound(0);
|
||||||
constBounds.push_back(0);
|
pack.pushConstantBound(memRefShape[i]);
|
||||||
boundTypes.push_back(0);
|
|
||||||
// Upper bound
|
|
||||||
constBounds.push_back(memRefShape[i]);
|
|
||||||
boundTypes.push_back(0);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, originalLoops,
|
|
||||||
optimizedLoops, operandBounds, constBounds, boundTypes);
|
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||||
Block& iterationBlock = iterateOp.bodyRegion().front();
|
Block& iterationBlock = iterateOp.bodyRegion().front();
|
||||||
|
|
||||||
// Now perform the insertions into the body of the
|
// Now perform the insertions into the body of the
|
||||||
|
@ -169,14 +158,12 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
||||||
SmallVector<Value*, 4> loopIVs;
|
SmallVector<Value*, 4> loopIVs;
|
||||||
for (auto arg : iterationBlock.getArguments())
|
for (auto arg : iterationBlock.getArguments())
|
||||||
loopIVs.push_back(arg);
|
loopIVs.push_back(arg);
|
||||||
auto loadedFirstVal =
|
auto loadedFirstVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
|
||||||
rewriter.create<LoadOp>(loc, operands[0], loopIVs);
|
auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
|
||||||
auto loadedSecondVal =
|
|
||||||
rewriter.create<LoadOp>(loc, operands[1], loopIVs);
|
|
||||||
|
|
||||||
// TODO: Choose type of the Add for now use the Float Add.
|
// TODO: Choose type of the Add for now use the Float Add.
|
||||||
auto addOpResult = rewriter.create<AddFOp>(
|
auto addOpResult =
|
||||||
loc, loadedFirstVal, loadedSecondVal);
|
rewriter.create<AddFOp>(loc, loadedFirstVal, loadedSecondVal);
|
||||||
|
|
||||||
// Store result in the resulting array.
|
// Store result in the resulting array.
|
||||||
rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
|
rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
|
||||||
|
@ -209,8 +196,8 @@ struct TensorTypeConverter : public TypeConverter {
|
||||||
/// inputs. Once unranked results can be handled gracefully this
|
/// inputs. Once unranked results can be handled gracefully this
|
||||||
/// override needs to be removed in favour of the original MLIR one.]
|
/// override needs to be removed in favour of the original MLIR one.]
|
||||||
bool isSignatureLegal(FunctionType funcType) {
|
bool isSignatureLegal(FunctionType funcType) {
|
||||||
return llvm::all_of(funcType.getInputs(),
|
return llvm::all_of(
|
||||||
[this](Type type) { return isLegal(type); });
|
funcType.getInputs(), [this](Type type) { return isLegal(type); });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -272,8 +259,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
||||||
// With the target and rewrite patterns defined, we can now attempt the
|
// With the target and rewrite patterns defined, we can now attempt the
|
||||||
// conversion. The conversion will signal failure if any of our `illegal`
|
// conversion. The conversion will signal failure if any of our `illegal`
|
||||||
// operations were not converted successfully.
|
// operations were not converted successfully.
|
||||||
if (failed(applyPartialConversion(
|
if (failed(applyPartialConversion(module, target, patterns)))
|
||||||
module, target, patterns)))
|
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,9 +17,10 @@ class Pass;
|
||||||
|
|
||||||
std::unique_ptr<Pass> createShapeInferencePass();
|
std::unique_ptr<Pass> createShapeInferencePass();
|
||||||
|
|
||||||
/// Pass for lowering frontend dialects to Krnl IR dialect.
|
/// Add pass for lowering to Krnl IR.
|
||||||
std::unique_ptr<mlir::Pass> createLowerToKrnlPass();
|
std::unique_ptr<Pass> createLowerToKrnlPass();
|
||||||
|
|
||||||
// TODO: Add pass for lowering to LLVM IR.
|
/// Pass for lowering frontend dialects to Krnl IR dialect.
|
||||||
|
std::unique_ptr<Pass> createLowerKrnlPass();
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
|
@ -75,7 +75,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||||
if (auto terminator_op = f.getBody().back().getTerminator()) {
|
if (auto terminator_op = f.getBody().back().getTerminator()) {
|
||||||
auto results = terminator_op->getOperandTypes();
|
auto results = terminator_op->getOperandTypes();
|
||||||
f.setType(FunctionType::get(f.getType().getInputs(),
|
f.setType(FunctionType::get(f.getType().getInputs(),
|
||||||
std::vector<Type>(results.begin(), results.end()), f.getContext()));
|
std::vector<Type>(results.begin(), results.end()), f.getContext()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,8 +6,7 @@ target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT})
|
||||||
|
|
||||||
target_link_libraries(onnf-opt compiler ${MLIRLibs})
|
target_link_libraries(onnf-opt compiler ${MLIRLibs})
|
||||||
whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs})
|
whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs})
|
||||||
whole_archive_link_onnf(onnf-opt onnf_lower_frontend)
|
whole_archive_link_onnf(onnf-opt onnf_transform onnf_lower_frontend onnf_shape_inference)
|
||||||
whole_archive_link_onnf(onnf-opt onnf_shape_inference)
|
|
||||||
|
|
||||||
# TODO: need to investigate how to whole-archive link compiler pass to onnf-opt.
|
# TODO: need to investigate how to whole-archive link compiler pass to onnf-opt.
|
||||||
target_link_libraries(onnf-opt compiler)
|
target_link_libraries(onnf-opt compiler)
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
add_library(onnf_transform lower_krnl.cpp)
|
||||||
|
|
||||||
|
target_include_directories(onnf_transform
|
||||||
|
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||||
|
${ONNF_SRC_ROOT})
|
||||||
|
target_link_libraries(onnf_transform ${MLIRLibs})
|
||||||
|
add_dependencies(onnf_transform gen_krnl_ops)
|
|
@ -0,0 +1,161 @@
|
||||||
|
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||||
|
#include "src/compiler/pass/passes.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Krnl to Affine Rewrite Patterns: KrnlIterate operation.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
||||||
|
using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(
|
||||||
|
KrnlIterateOp iterateOp, PatternRewriter& rewriter) const override {
|
||||||
|
auto boundMapAttrs =
|
||||||
|
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
|
||||||
|
.getValue();
|
||||||
|
auto operandItr =
|
||||||
|
iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops();
|
||||||
|
SmallVector<AffineForOp, 4> nestedForOps;
|
||||||
|
for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) {
|
||||||
|
// Consume input loop operand, currently do not do anything with it.
|
||||||
|
operandItr++;
|
||||||
|
|
||||||
|
// Organize operands into lower/upper bounds in affine.for ready formats.
|
||||||
|
SmallVector<Value*, 4> lbOperands, ubOperands;
|
||||||
|
AffineMap lbMap, ubMap;
|
||||||
|
for (int boundType = 0; boundType < 2; boundType++) {
|
||||||
|
auto& operands = boundType == 0 ? lbOperands : ubOperands;
|
||||||
|
auto& map = boundType == 0 ? lbMap : ubMap;
|
||||||
|
map = boundMapAttrs[boundIdx + boundType]
|
||||||
|
.cast<AffineMapAttr>()
|
||||||
|
.getValue();
|
||||||
|
operands.insert(
|
||||||
|
operands.end(), operandItr, operandItr + map.getNumInputs());
|
||||||
|
std::advance(operandItr, map.getNumInputs());
|
||||||
|
}
|
||||||
|
|
||||||
|
nestedForOps.emplace_back(rewriter.create<AffineForOp>(
|
||||||
|
iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap));
|
||||||
|
rewriter.setInsertionPoint(nestedForOps.back().getBody(),
|
||||||
|
nestedForOps.back().getBody()->begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace induction variable references from those introduced by a
|
||||||
|
// single krnl.iterate to those introduced by multiple affine.for
|
||||||
|
// operations.
|
||||||
|
for (size_t i = 0; i < nestedForOps.size() - 1; i++) {
|
||||||
|
auto iterateIV = iterateOp.bodyRegion().front().getArgument(0);
|
||||||
|
auto forIV = nestedForOps[i].getBody()->getArgument(0);
|
||||||
|
iterateIV->replaceAllUsesWith(forIV);
|
||||||
|
iterateOp.bodyRegion().front().eraseArgument(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pop krnl.iterate body region block arguments, leave the last one
|
||||||
|
// for convenience (it'll be taken care of by region inlining).
|
||||||
|
while (iterateOp.bodyRegion().front().getNumArguments() > 1)
|
||||||
|
iterateOp.bodyRegion().front().eraseArgument(0);
|
||||||
|
|
||||||
|
// Transfer krnl.iterate region to innermost for op.
|
||||||
|
auto innermostForOp = nestedForOps.back();
|
||||||
|
innermostForOp.region().getBlocks().clear();
|
||||||
|
rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(),
|
||||||
|
innermostForOp.region().end());
|
||||||
|
|
||||||
|
rewriter.eraseOp(iterateOp);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Krnl to Affine Rewrite Patterns: KrnlTerminator operation.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(
|
||||||
|
KrnlTerminatorOp op, PatternRewriter& rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Krnl to Affine Rewrite Patterns: KrnlDefineLoops operation.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(
|
||||||
|
KrnlDefineLoopsOp op, PatternRewriter& rewriter) const override {
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Krnl to Affine Rewrite Patterns: KrnlOptimizeLoops operation.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(
|
||||||
|
KrnlOptimizeLoopsOp op, PatternRewriter& rewriter) const override {
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// KrnlToAffineLoweringPass
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// This is a partial lowering to affine loops of the krnl dialect operations.
|
||||||
|
/// At this stage the dialect will contain standard operations as well like
|
||||||
|
/// add and multiply, this pass will leave these operations intact.
|
||||||
|
namespace {
|
||||||
|
struct KrnlToAffineLoweringPass
|
||||||
|
: public FunctionPass<KrnlToAffineLoweringPass> {
|
||||||
|
void runOnFunction() final;
|
||||||
|
};
|
||||||
|
} // end anonymous namespace.
|
||||||
|
|
||||||
|
void KrnlToAffineLoweringPass::runOnFunction() {
|
||||||
|
auto function = getFunction();
|
||||||
|
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
|
||||||
|
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
|
||||||
|
// We expect IR to be free of Krnl Dialect Ops.
|
||||||
|
target.addIllegalDialect<KrnlOpsDialect>();
|
||||||
|
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
||||||
|
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext());
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(getFunction(), target, patterns)))
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
|
||||||
|
return std::make_unique<KrnlToAffineLoweringPass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<KrnlToAffineLoweringPass> pass(
|
||||||
|
"lower-krnl", "Lower Krnl dialect.");
|
16
src/main.cpp
16
src/main.cpp
|
@ -28,6 +28,7 @@
|
||||||
|
|
||||||
#include <boost/program_options.hpp>
|
#include <boost/program_options.hpp>
|
||||||
|
|
||||||
|
#include "llvm/Bitcode/BitcodeWriter.h"
|
||||||
#include "llvm/Support/FileUtilities.h"
|
#include "llvm/Support/FileUtilities.h"
|
||||||
#include "llvm/Support/Regex.h"
|
#include "llvm/Support/Regex.h"
|
||||||
#include "llvm/Support/SourceMgr.h"
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
@ -38,6 +39,8 @@
|
||||||
#include "src/compiler/pass/passes.hpp"
|
#include "src/compiler/pass/passes.hpp"
|
||||||
|
|
||||||
#include "mlir/Analysis/Verifier.h"
|
#include "mlir/Analysis/Verifier.h"
|
||||||
|
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
||||||
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
@ -125,7 +128,20 @@ int main(int ac, char* av[]) {
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
pm.addPass(mlir::createLowerToKrnlPass());
|
pm.addPass(mlir::createLowerToKrnlPass());
|
||||||
|
pm.addPass(mlir::createLowerKrnlPass());
|
||||||
|
pm.addPass(mlir::createLowerAffinePass());
|
||||||
|
pm.addPass(mlir::createLowerToCFGPass());
|
||||||
|
pm.addPass(mlir::createLowerToLLVMPass());
|
||||||
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
pm.run(*module);
|
pm.run(*module);
|
||||||
|
|
||||||
|
// Write LLVM bitcode to disk.
|
||||||
|
std::error_code EC;
|
||||||
|
llvm::raw_fd_ostream moduleBitcodeStream(
|
||||||
|
"model.bc", EC, llvm::sys::fs::F_None);
|
||||||
|
llvm::WriteBitcodeToFile(
|
||||||
|
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
|
||||||
|
moduleBitcodeStream.flush();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
// RUN: onnf-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s
|
||||||
|
// RUN: onnf-opt %s | FileCheck %s
|
||||||
|
|
||||||
|
// GENERIC-DAG: #{{.*}} = () -> (0)
|
||||||
|
// GENERIC-DAG: #{{.*}} = () -> (10)
|
||||||
|
// GENERIC-DAG: #{{.*}} = () -> (1)
|
||||||
|
// GENERIC-DAG: #{{.*}} = () -> (11)
|
||||||
|
// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 - d1)
|
||||||
|
// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 + d1)
|
||||||
|
|
||||||
|
func @simple_iterate(%N : index) {
|
||||||
|
%ii, %ij, %ik = krnl.define_loops 3
|
||||||
|
%oi, %oj, %ok = krnl.optimize_loops {
|
||||||
|
krnl.return_loops %ii, %ij, %ik
|
||||||
|
} : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
|
||||||
|
|
||||||
|
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||||
|
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
|
||||||
|
// GENERIC-NEXT: "krnl.terminate"() : () -> ()
|
||||||
|
// GENERIC-NEXT: bounds = [#{{.*}}, #{{.*}}, #{{.*}}, #{{.*}}]
|
||||||
|
|
||||||
|
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 1 to 11) {
|
||||||
|
krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 1 to 11) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||||
|
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
|
||||||
|
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) {
|
||||||
|
krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 0 to 10) {
|
||||||
|
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}) ( {
|
||||||
|
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
|
||||||
|
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10) {
|
||||||
|
krnl.iterate(%ok) with (%ik -> %k = 0 to 10) {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||||
|
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
|
||||||
|
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to %{{.*}}, %{{.*}} -> %{{.*}} = 0 to 10) {
|
||||||
|
krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to %N, %ij -> %j = 0 to 10) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func @affine_map_bound(%N : index) {
|
||||||
|
%ii, %ij, %ik = krnl.define_loops 3
|
||||||
|
%oi, %oj, %ok = krnl.optimize_loops {
|
||||||
|
krnl.return_loops %ii, %ij, %ik
|
||||||
|
} : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
|
||||||
|
|
||||||
|
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||||
|
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
|
||||||
|
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) {
|
||||||
|
krnl.iterate(%oi, %oj) with (%ii -> %i = ()->(0)() to ()->(10)(), %ij -> %j = 0 to 10) {
|
||||||
|
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||||
|
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
|
||||||
|
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = #{{.*}}(%{{.*}}, %{{.*}}) to #{{.*}}(%{{.*}}, %{{.*}})) {
|
||||||
|
krnl.iterate(%ok) with (%ik -> %k = (d0, d1)->(d0 - d1)(%i, %j) to (d0, d1)->(d0 + d1)(%i, %j)) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||||
|
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
|
||||||
|
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = max #map{{.*}}(%{{.*}}, %{{.*}}) to min #map{{.*}}(%{{.*}}, %{{.*}})[%{{.*}}]) {
|
||||||
|
krnl.iterate(%ok) with (%ik -> %k = max (d0, d1)->(d0 - d1, 0)(%i, %j) to min (d0, d1)[s0]->(d0 + d1, s0)(%i, %j)[%N]) {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
Loading…
Reference in New Issue