[MLIR] Refactor Krnl Dialect and Krnl Dialect Lowering (#375)

* Store bounds as affine map attributes & check in test cases with generic printer

* Upgrading MLIR

MLIR is outdated on Buildbot, rebuilding a newer version.

* work with new version of mlir

* check-in parser tests

* custom printer

* nit

* bug fix

* enable custom asm printer test

* enable custom asm printer test

* more consistent variable naming

* test max/min

* variable naming scheme change to MLIR style

* can lower krnl to llvm

* kernel -> llvm

* comments

* bug fix

* try fixing ci

* fix ci

* deactivate model test

* fix lit test

* nit

* fix z buildbot
Tian Jin, 2019-11-27 22:56:34 -05:00 (committed by Tian Jin)
commit b2a1103915, parent 652ce4b7d4
19 changed files with 707 additions and 366 deletions


@ -18,6 +18,8 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
#TODO(eventually enable the following)
#set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+include(MLIR.cmake)
add_subdirectory(third_party/onnx)
add_subdirectory(third_party/benchmark)
add_subdirectory(third_party/pybind11)
@ -41,7 +43,6 @@ if(Boost_FOUND)
include_directories(${Boost_INCLUDE_DIRS})
endif()
-include(MLIR.cmake)
add_subdirectory(src/builder)
add_subdirectory(src/compiler)
add_subdirectory(src)


@ -5,7 +5,7 @@ if(DEFINED ENV{LLVM_SRC})
message(STATUS "LLVM_SRC " ${LLVM_SRC})
else()
message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: "
${LLVM_SRC})
endif()
else()
message(FATAL_ERROR "env variable LLVM_SRC not set")
@ -18,7 +18,7 @@ if(DEFINED ENV{LLVM_BUILD})
message(STATUS "LLVM_BUILD " ${LLVM_BUILD})
else()
message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: "
${LLVM_BUILD})
endif()
else()
message(FATAL_ERROR "env variable LLVM_BUILD not set")
@ -39,9 +39,9 @@ set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir)
set(
MLIR_INCLUDE_PATHS
${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH}
)
include_directories(${MLIR_INCLUDE_PATHS})
# Threading libraries required due to parallel pass execution.
@ -49,9 +49,9 @@ find_package(Threads REQUIRED)
function(find_mlir_lib lib)
find_library(${lib}
NAMES ${lib}
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
endfunction(find_mlir_lib)
find_mlir_lib(MLIRAffineOps)
@ -70,6 +70,10 @@ find_mlir_lib(MLIRTransforms)
find_mlir_lib(MLIRTransformUtils)
find_mlir_lib(MLIRSupport)
find_mlir_lib(MLIROptMain)
+find_mlir_lib(MLIRTargetLLVMIRModuleTranslation)
+find_mlir_lib(MLIRTargetLLVMIR)
+find_mlir_lib(MLIRTransformUtils)
+find_mlir_lib(MLIRTranslation)
find_mlir_lib(MLIRVectorOps)
find_mlir_lib(LLVMCore)
@ -80,46 +84,52 @@ find_mlir_lib(LLVMRemarks)
find_mlir_lib(LLVMIRReader)
find_mlir_lib(LLVMTransformUtils)
find_mlir_lib(LLVMBitstreamReader)
+find_mlir_lib(LLVMAnalysis)
+find_mlir_lib(LLVMBitWriter)
+find_mlir_lib(LLVMBitReader)
+find_mlir_lib(LLVMMC)
+find_mlir_lib(LLVMMCParser)
+find_mlir_lib(LLVMObject)
+find_mlir_lib(LLVMProfileData)
+find_mlir_lib(LLVMDemangle)
set(MLIRLibsOnce
+LLVMAnalysis
+LLVMAsmParser
+LLVMBinaryFormat
+LLVMBitReader
+LLVMBitstreamReader
+LLVMBitWriter
+LLVMCore
+LLVMIRReader
+LLVMMC
+LLVMMCParser
+LLVMObject
+LLVMRemarks
+LLVMSupport
+LLVMTransformUtils
+LLVMProfileData
+LLVMDemangle
MLIRAffineOps
MLIRAffineToStandard
MLIRAnalysis
MLIRExecutionEngine
MLIRIR
MLIRLLVMIR
+MLIRLoopOps
MLIRLoopToStandard
+MLIROptMain
MLIRParser
MLIRPass
MLIRStandardOps
MLIRStandardToLLVM
+MLIRSupport
MLIRTargetLLVMIR
-MLIRTransforms
-MLIRAffineOps
-MLIRAffineToStandard
-MLIRAnalysis
-MLIRExecutionEngine
-MLIRIR
-MLIRLLVMIR
-MLIRLoopToStandard
-MLIRParser
-MLIRPass
-MLIRStandardOps
-MLIRStandardToLLVM
-MLIRTargetLLVMIR
+MLIRTargetLLVMIRModuleTranslation
MLIRTransforms
MLIRTransformUtils
-MLIRLoopOps
-MLIRSupport
-MLIROptMain
-LLVMCore
-LLVMSupport
-LLVMAsmParser
-LLVMIRReader
-LLVMTransformUtils
-LLVMBinaryFormat
-LLVMRemarks
-LLVMBitstreamReader)
+MLIRTranslation)
set(MLIRLibs
${MLIRLibsOnce}
@ -142,7 +152,7 @@ function(whole_archive_link target lib_dir)
set(link_flags "${link_flags} -L${lib_dir} ")
foreach(LIB ${ARGN})
string(CONCAT link_flags ${link_flags}
"-Wl,-force_load ${lib_dir}/lib${LIB}.a ")
endforeach(LIB)
elseif(MSVC)
foreach(LIB ${ARGN})
@ -170,20 +180,20 @@ function(whole_archive_link_onnf target)
endfunction(whole_archive_link_onnf)
set(LLVM_CMAKE_DIR
"${LLVM_BUILD}/lib/cmake/llvm"
CACHE PATH "Path to LLVM cmake modules")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(AddLLVM)
include(TableGen)
function(onnf_tablegen ofn)
tablegen(MLIR
${ARGV}
"-I${MLIR_SRC_INCLUDE_PATH}"
"-I${MLIR_BIN_INCLUDE_PATH}")
set(TABLEGEN_OUTPUT
${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
PARENT_SCOPE)
endfunction()
# Import the pre-built mlir TableGen as an imported exetuable. It is required by
@ -191,5 +201,5 @@ endfunction()
# table gen utility itself can be detected and cause re-compilation of .td file.
add_executable(mlir-tblgen IMPORTED)
set_property(TARGET mlir-tblgen
PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen)
set(MLIR_TABLEGEN_EXE mlir-tblgen)


@ -2,6 +2,7 @@ add_executable(onnf main.cpp)
target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES})
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
+whole_archive_link_onnf(onnf onnf_transform)
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})


@ -6,8 +6,8 @@ add_library(
dialect/krnl/krnl_types.hpp
dialect/onnx/onnx_ops.cpp
dialect/onnx/onnx_ops.hpp
-dialect/krnl/parser_helper.cpp
-dialect/krnl/parser_helper.hpp
+dialect/krnl/krnl_helper.cpp
+dialect/krnl/krnl_helper.hpp
pass/shape_inference_pass.cpp
pass/shape_inference_interface.hpp
dialect/onnx/onnxop.inc
@ -82,4 +82,5 @@ target_include_directories(onnf_lower_frontend
target_link_libraries(onnf_lower_frontend ${MLIRLibs})
add_dependencies(onnf_lower_frontend gen_krnl_ops)
+add_subdirectory(transform)
add_subdirectory(tool)


@ -0,0 +1,139 @@
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/IR/AffineExpr.h"
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "krnl_helper.hpp"
namespace onnf {
using namespace mlir;
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
const Type& operandType, Value*& operand) {
// If operand queue is empty, parse more operands and cache them.
if (_operandRefQueue.empty()) {
// Parse operand types:
llvm::SmallVector<OpAsmParser::OperandType, 2> operand_refs;
_parser.parseOperandList(operand_refs);
// Record operands:
for (auto& operand_ref : operand_refs)
_operandRefQueue.emplace(operand_ref);
}
// If we parsed some operand reference(s), resolve the ref to an operand:
if (!_operandRefQueue.empty()) {
auto operand_ref = _operandRefQueue.front();
_operandRefQueue.pop();
llvm::SmallVector<Value*, 1> operands;
_parser.resolveOperand(operand_ref, operandType, operands);
operand = operands.front();
return success();
} else {
operand = nullptr;
return failure();
}
}
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
Value* operand = nullptr;
if (ParseOptionalOperand(operandType, operand))
return failure();
operandList.emplace_back(operand);
return success();
}
ParseResult KrnlDialectOperandParser::ParseOperand(
const Type& operandType, Value*& operand) {
if (ParseOptionalOperand(operandType, operand))
return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand.");
return success();
}
ParseResult KrnlDialectOperandParser::ParseOperand(
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
if (ParseOptionalOperand(operandType, operandList))
return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand.");
return success();
}
void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
unsigned numSymbols, OpAsmPrinter& p) {
p << '(';
p.printOperands(begin, begin + numDims);
p << ')';
if (numSymbols) {
p << '[';
p.printOperands(begin + numDims, begin + numDims + numSymbols);
p << ']';
}
begin = std::next(begin, numDims + numSymbols);
}
void printBound(AffineMapAttr boundMap,
Operation::operand_iterator& boundOperandsBeg, const char* prefix,
OpAsmPrinter& p) {
AffineMap map = boundMap.getValue();
// Check if this bound should be printed using custom assembly form.
// The decision to restrict printing custom assembly form to trivial cases
// comes from the will to roundtrip MLIR binary -> text -> binary in a
// lossless way.
// Therefore, custom assembly form parsing and printing is only supported for
// zero-operand constant maps and single symbol operand identity maps.
if (map.getNumResults() == 1) {
AffineExpr expr = map.getResult(0);
// Print constant bound.
if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
p << constExpr.getValue();
return;
}
}
// Print bound that consists of a single SSA symbol if the map is over a
// single symbol.
if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
p.printOperand(*(boundOperandsBeg++));
return;
}
}
} else {
// Map has multiple results. Print 'min' or 'max' prefix.
p << prefix << ' ';
}
// Print the map and its operands.
p << boundMap;
printDimAndSymbolList(
boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p);
}
} // namespace onnf
namespace mlir {
void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
AffineMap map = builder.getConstantAffineMap(bound);
boundMaps.emplace_back(AffineMapAttr::get(map));
}
void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) {
if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
AffineMap map = builder.getSymbolIdentityMap();
boundMaps.emplace_back(AffineMapAttr::get(map));
_operands.emplace_back(operand);
}
} // namespace mlir
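Note: printBound above deliberately restricts the short custom form to two trivial map shapes so that the custom syntax can be parsed back losslessly. A minimal illustration of the three printing cases, assuming a mlir::Builder b is in scope (sketch only, not part of the patch):

// Zero-operand constant map: prints as a bare integer literal, e.g. "0" or "10".
auto constBound = mlir::AffineMapAttr::get(b.getConstantAffineMap(10));
// Single-symbol identity map: prints as the one SSA operand it consumes, e.g. "%N".
auto operandBound = mlir::AffineMapAttr::get(b.getSymbolIdentityMap());
// Any other map (extra dims/symbols or multiple results) falls back to printing the
// affine map itself plus its dim/symbol operand list, with a "max"/"min" prefix
// added only when the map has more than one result.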


@ -0,0 +1,102 @@
#pragma once
#include <queue>
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
namespace onnf {
class KrnlDialectOperandParser {
public:
explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser)
: _parser(parser), _builder(parser.getBuilder()){};
// Parse an optional operand.
mlir::ParseResult ParseOptionalOperand(
const mlir::Type& operandType, mlir::Value*& operand);
// Parse an optional operand and push it to an operand list.
mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType,
llvm::SmallVectorImpl<mlir::Value*>& operandList);
// Parse a required operand.
mlir::ParseResult ParseOperand(
const mlir::Type& operandType, mlir::Value*& operand);
// Parse a required operand and push it to an operand list.
mlir::ParseResult ParseOperand(const mlir::Type& operandType,
llvm::SmallVectorImpl<mlir::Value*>& operandList);
// Do we have more operands to parse?
bool hasOperandLeft() { return !_operandRefQueue.empty(); }
private:
mlir::OpAsmParser& _parser;
mlir::Builder& _builder;
// A queue storing the parsed SSA id references.
std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue;
};
// Adapted from:
// https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
// Main difference is that it advances the iterator `begin` as it consumes
// dimension and symbol operands.
void printDimAndSymbolList(mlir::Operation::operand_iterator& begin,
unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p);
// Adapted from:
// https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
// Main difference is that it advances the iterator `boundOperandsBeg` as it
// prints bound.
void printBound(mlir::AffineMapAttr boundMap,
mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix,
mlir::OpAsmPrinter& p);
} // namespace onnf
namespace mlir {
struct KrnlIterateOperandPack {
KrnlIterateOperandPack(mlir::Builder& builder,
llvm::ArrayRef<mlir::Value*> inputLoops,
llvm::ArrayRef<mlir::Value*> optimizedLoops)
: builder(builder),
inputLoops(inputLoops),
optimizedLoops(optimizedLoops) {
_operands.insert(
_operands.end(), optimizedLoops.begin(), optimizedLoops.end());
}
void pushConstantBound(int64_t bound);
void pushOperandBound(mlir::Value* operand);
llvm::SmallVector<mlir::Value*, 8> getOperands() const { return _operands; }
mlir::ArrayAttr getAttributes() const {
return builder.getArrayAttr(boundMaps);
}
size_t getNumOptimizedLoops() const { return optimizedLoops.size(); }
size_t getNumInputLoops() const { return inputLoops.size(); }
private:
int _boundIdx = 0;
llvm::SmallVector<mlir::Value*, 8> _operands;
llvm::SmallVector<mlir::Attribute, 8> boundMaps;
llvm::ArrayRef<mlir::Value*> inputLoops, optimizedLoops;
mlir::Builder& builder;
};
} // namespace mlir
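Note: for context on how this pack is meant to be used, the ONNXAddOp lowering later in this patch builds one (lower, upper) bound pair per loop and then hands the pack to the KrnlIterateOp builder. A condensed sketch (names such as dimValue are illustrative, not from the patch):

KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Starting a new bound pair also records the matching input-loop operand internally.
pack.pushConstantBound(0);        // lower bound: constant 0
pack.pushOperandBound(dimValue);  // upper bound: an SSA index value, e.g. a DimOp result
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block& body = iterateOp.bodyRegion().front();  // one index block argument per input loop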


@ -9,10 +9,10 @@
#include <iostream>
#include <queue>
-#include "src/compiler/dialect/krnl/parser_helper.hpp"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
@ -24,6 +24,8 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "src/compiler/dialect/krnl/krnl_helper.hpp"
#include "krnl_ops.hpp"
using namespace mlir;
@ -52,26 +54,25 @@ void KrnlDefineLoopsOp::build(
}
void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) {
-auto num_loop_attr = op.getAttrOfType<IntegerAttr>(op.getNumLoopsAttrName());
-p << "krnl.define_loops " << num_loop_attr.getValue().getSExtValue();
+auto numLoopAttr =
+    op.getAttrOfType<IntegerAttr>(KrnlDefineLoopsOp::getNumLoopsAttrName());
+p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue();
}
ParseResult parseKrnlDefineLoopsOp(
OpAsmParser& parser, OperationState& result) {
// Parse the attribute indicating number of loops defined.
-IntegerAttr num_loops;
+IntegerAttr numLoops;
auto& builder = parser.getBuilder();
-auto int32_type = builder.getIntegerType(64);
-if (parser.parseAttribute(num_loops, int32_type,
+auto intType = builder.getIntegerType(64);
+if (parser.parseAttribute(numLoops, intType,
KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
return failure();
-auto loop_types = llvm::SmallVector<Type, 4>(
-    num_loops.getValue().getSExtValue(), LoopType::get(builder.getContext()));
-if (parser.addTypesToList(loop_types, result.types))
+auto loopTypes = llvm::SmallVector<Type, 4>(
+    numLoops.getValue().getSExtValue(), LoopType::get(builder.getContext()));
+if (parser.addTypesToList(loopTypes, result.types))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
@ -142,39 +143,14 @@ ParseResult parseKrnlOptimizeLoopsOp(
* %i0 = 10 to N : %i1 = M to 20
*/
void KrnlIterateOp::build(Builder* builder, OperationState& result,
-ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops,
-ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds,
-ArrayRef<int> bound_types) {
+KrnlIterateOperandPack operandPack) {
// Record optimized loops and the number of such loops.
-result.addOperands(optimized_loops);
+result.addOperands(operandPack.getOperands());
+result.addAttribute(
+    KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes());
result.addAttribute(getNumOptimizedLoopsAttrName(),
-    builder->getI64IntegerAttr(optimized_loops.size()));
+    builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops()));
-// Record input loops and the number of such loops.
-result.addOperands(input_loops);
-result.addAttribute(getNumInputLoopsAttrName(),
-    builder->getI64IntegerAttr(input_loops.size()));
-// Record bound either as attribute or from operand list.
-auto next_operand_bound = operand_bounds.begin();
-auto next_const_bound = const_bounds.begin();
-for (size_t i = 0; i < bound_types.size(); i++) {
-auto bound_type = bound_types[i];
-if (bound_type == 0) {
-// Constant bound.
-result.addAttribute(getBoundAttrName(i / 2, i % 2),
-    builder->getI64IntegerAttr(*next_const_bound));
-next_const_bound = std::next(next_const_bound);
-} else {
-// Operand bound.
-result.addOperands(*next_operand_bound);
-next_operand_bound = std::next(next_operand_bound);
-}
-}
-// Record bound types as attribute:
-result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
-    builder->getI32ArrayAttr(bound_types));
// Create a region and a block for the body. The arguments of the region are
// the loop induction variables; there can be multiple induction variables
@ -182,7 +158,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
Region* bodyRegion = result.addRegion();
auto* body = new Block();
auto body_args = llvm::SmallVector<Type, 4>(
-    input_loops.size(), IndexType::get(builder->getContext()));
+    operandPack.getNumInputLoops(), IndexType::get(builder->getContext()));
body->addArguments(body_args);
bodyRegion->push_back(body);
@ -192,57 +168,31 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
void print(OpAsmPrinter& p, KrnlIterateOp& op) {
p << "krnl.iterate(";
// Print optimized loops:
-auto num_optimized_loops = op.getNumOptimizedLoops();
-p.printOperands(op.operand_begin(), op.operand_begin() + num_optimized_loops);
+auto numOptimizedLoops = op.getNumOptimizedLoops();
+p.printOperands(op.operand_begin(), op.operand_begin() + numOptimizedLoops);
p << ") with (";
-// Set up iterator to input loops:
-auto num_input_loops = op.getNumInputLoops();
-auto input_loop_begin = op.operand_begin() + num_optimized_loops;
-// Set up iterators to operand bounds.
-auto next_operand_bound = input_loop_begin + num_input_loops;
-// Function to print a lower or upper bound.
-auto print_bound = [&](ArrayRef<Attribute> bound_types, size_t idx) {
-IntegerAttr type = bound_types[idx].dyn_cast<IntegerAttr>();
-if (type.getValue().getSExtValue() == 0) {
-// Bound is an integer attribute.
-auto bound_idx = idx / 2;
-auto is_ub = idx % 2;
-IntegerAttr bound = op.getAttrOfType<IntegerAttr>(
-    KrnlIterateOp::getBoundAttrName(bound_idx, is_ub));
-p << bound.getValue().getSExtValue();
-} else {
-// Bound is an operand.
-p.printOperand(*next_operand_bound);
-next_operand_bound = std::next(next_operand_bound);
-}
-};
-auto induction_variables = op.bodyRegion().front().getArguments();
-ArrayRef<Attribute> bound_types =
-    op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundTypesAttrName())
-        .getValue();
-// Print input loop operands, induction variables and their ranges.
-for (size_t i = 0; i < num_input_loops; i++) {
-if (i != 0)
-p << ", ";
-p.printOperand(*std::next(input_loop_begin, i));
+auto inductionVars = op.bodyRegion().begin()->getArguments();
+auto boundItr =
+    op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
+        .getValue()
+        .begin();
+auto operandItr = op.operand_begin() + numOptimizedLoops;
+std::string delimiter;
+for (auto& var : inductionVars) {
+p << delimiter;
+p.printOperand(*operandItr++);
p << " -> ";
-// Print induction variable block argument.
-p.printOperand(induction_variables[i]);
+p.printOperand(var);
p << " = ";
-print_bound(bound_types, 2 * i); // Print lower bound.
+onnf::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "max", p);
p << " to ";
-print_bound(bound_types, 2 * i + 1); // Print upper bound.
+onnf::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "min", p);
+delimiter = ", ";
}
p << ")";
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}
@ -250,80 +200,109 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
auto builder = parser.getBuilder();
auto context = builder.getContext();
-onnf::KrnlDialectOperandParser operand_parser(parser);
+onnf::KrnlDialectOperandParser operandParser(parser);
// Parse optimized loops:
-SmallVector<OpAsmParser::OperandType, 4> num_optimized_loops;
+SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs;
if (parser.parseOperandList(
-        num_optimized_loops, OpAsmParser::Delimiter::Paren) ||
-    parser.resolveOperands(num_optimized_loops,
+        optimizedLoopRefs, OpAsmParser::Delimiter::Paren) ||
+    parser.resolveOperands(optimizedLoopRefs,
LoopType::get(result.getContext()), result.operands))
return failure();
// Record how many optimized loops did we parse.
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
-    builder.getI64IntegerAttr(num_optimized_loops.size()));
+    builder.getI64IntegerAttr(optimizedLoopRefs.size()));
// Parse input loops and their lower and upper bounds.
-SmallVector<OpAsmParser::OperandType, 4> in_loop_refs, induction_var_refs;
-SmallVector<Value*, 4> in_loop_operands, operand_bounds;
-SmallVector<Attribute, 4> bound_types;
-SmallVector<IntegerAttr, 4> const_bounds;
+SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs;
+SmallVector<Attribute, 4> boundMaps;
if (parser.parseKeyword("with") || parser.parseLParen())
return failure();
// A function to parse a lower or upper bound.
-auto parse_bound = [&result, &builder, &operand_parser, &parser, &bound_types,
-                       &operand_bounds, &const_bounds](
-    bool is_ub, size_t bound_pair_count) -> ParseResult {
+auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps](
+                      bool isUpper) -> ParseResult {
+// 'min' / 'max' prefixes are generally syntactic sugar, but are required if
+// the map has multiple results.
+bool failedToParsedMinMax =
+    failed(parser.parseOptionalKeyword(isUpper ? "min" : "max"));
// Try parse an SSA operand.
-Value* bound;
-operand_parser.ParseOptionalOperand(builder.getIndexType(), bound);
-if (bound != nullptr) {
-// Parsed an SSA id as bound.
-operand_bounds.emplace_back(bound);
-// Record bound_type as an operand type.
-bound_types.emplace_back(builder.getI32IntegerAttr(0));
-} else {
-// Bound is not an SSA id, then it must be an integer.
-// Parse an integer constant attribute.
-IntegerAttr boundAttr;
-if (parser.parseAttribute(boundAttr, builder.getIndexType(),
-    KrnlIterateOp::getBoundAttrName(bound_pair_count, is_ub),
-    result.attributes))
+if (succeeded(operandParser.ParseOptionalOperand(
+        builder.getIndexType(), result.operands))) {
+AffineMap map = builder.getSymbolIdentityMap();
+boundMaps.emplace_back(AffineMapAttr::get(map));
+return success();
+}
+// Bound is not an SSA id, then it must be an integer.
+// Parse an integer constant attribute.
+// Get the attribute location.
+llvm::SMLoc attrLoc = parser.getCurrentLocation();
+Attribute boundAttr;
+llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer;
+if (parser.parseAttribute(
+        boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer))
+return failure();
+if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
+unsigned currentNumOperands = result.operands.size();
+unsigned numDims = 0;
+if (parseDimAndSymbolList(parser, result.operands, numDims))
return failure();
-const_bounds.emplace_back(
-    builder.getIntegerAttr(builder.getIndexType(), boundAttr.getValue()));
-// Record that the bound_type is a constant integer attribute.
-bound_types.emplace_back(builder.getI32IntegerAttr(1));
+auto map = affineMapAttr.getValue();
+if (map.getNumDims() != numDims)
+return parser.emitError(parser.getNameLoc(),
+    "dim operand count and integer set dim count must match");
+unsigned numDimAndSymbolOperands =
+    result.operands.size() - currentNumOperands;
+if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
+return parser.emitError(parser.getNameLoc(),
+    "symbol operand count and integer set symbol count must match");
+// If the map has multiple results, make sure that we parsed the min/max
+// prefix.
+if (map.getNumResults() > 1 && failedToParsedMinMax) {
+if (isUpper)
+return parser.emitError(attrLoc,
+    "upper loop bound affine map with multiple "
+    "results requires 'min' prefix");
+return parser.emitError(attrLoc,
+    "lower loop bound affine mapwith "
+    "multiple results requires 'max' prefix");
+}
+boundMaps.emplace_back(AffineMapAttr::get(map));
+return success();
+}
+if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
+AffineMap map =
+    builder.getConstantAffineMap(integerAttr.getValue().getSExtValue());
+boundMaps.emplace_back(AffineMapAttr::get(map));
}
};
-bool keep_parsing; // Do we keep parsing loops/bounds?
-size_t bound_pair_count = 0; // Record the number of bound pairs parsed.
+bool keepParsing; // Do we keep parsing loops/bounds?
do {
// Parse an input loop operand;
-Value* in_loop_operand;
-operand_parser.ParseOperand(LoopType::get(context), in_loop_operand);
-in_loop_operands.emplace_back(in_loop_operand);
+operandParser.ParseOperand(LoopType::get(context), result.operands);
parser.parseArrow();
// Parse induction variable.
-OpAsmParser::OperandType induction_var;
-if (parser.parseRegionArgument(induction_var) || parser.parseEqual())
+OpAsmParser::OperandType inductionVar;
+if (parser.parseRegionArgument(inductionVar) || parser.parseEqual())
return failure();
-induction_var_refs.emplace_back(induction_var);
+inductionVarRefs.emplace_back(inductionVar);
// Parse bound par (min to max).
-if (parse_bound(false, bound_pair_count) || parser.parseKeyword("to") ||
-    parse_bound(true, bound_pair_count))
+if (parseBound(/*isUpper=*/false) || parser.parseKeyword("to") ||
+    parseBound(/*isUpper=*/true))
return failure();
-bound_pair_count++;
// We may fail to parse a comma if an operand bound is followed by
// a comma and the next input loop operand, in which case
// the entire "{operand bound}, {input_loop_operand}" sequence will
@ -331,33 +310,19 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
parser.parseOptionalComma();
// If we don't see a RParen token, we keep parsing.
-keep_parsing = failed(parser.parseOptionalRParen());
-} while (keep_parsing);
+keepParsing = failed(parser.parseOptionalRParen());
+} while (keepParsing);
// At this point, there shouldn't be any operands left to parse.
-if (operand_parser.has_operand_left())
+if (operandParser.hasOperandLeft())
return parser.emitError(parser.getCurrentLocation());
+result.addAttribute(
+    KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps));
-// Record how many input loops did we parse.
-result.addOperands(in_loop_operands);
-result.addAttribute(KrnlIterateOp::getNumInputLoopsAttrName(),
-    builder.getI64IntegerAttr(in_loop_operands.size()));
-// Add operand bounds to the list of operands of current operation.
-result.addOperands(operand_bounds);
-// A list of 2N elements where the (2n) and (2n+1) th element specifies
-// whether the lower and upper bound of the n'th induction variable is stored
-// as an operand or as an attribute. N being the number of input loops
-// specified in this krnl.iterate operation.
-result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
-    builder.getArrayAttr(bound_types));
-// Parse the schedule body region.
Region* region = result.addRegion();
-SmallVector<Type, 4> induction_var_types(
-    induction_var_refs.size(), builder.getIndexType());
-if (parser.parseRegion(*region, induction_var_refs, induction_var_types))
+SmallVector<Type, 4> inductionVarTypes(
+    inductionVarRefs.size(), builder.getIndexType());
+if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes))
return failure();
// Ensure iterate region is closed off with krnl.terminate.
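Note: since the bound storage changed, it is worth spelling out the operand layout that the new builder, parser, and printer all agree on; this is my reading of KrnlIterateOperandPack and parseKrnlIterateOp, not text from the patch:

// Operand list of krnl.iterate after this refactor (sketch):
//   optimized loops                     -- count given by the "num_optimized_loops" attribute
//   then, per input loop i:
//     the input loop value (!krnl.loop)
//     dim/symbol operands of bounds[2*i]    (lower-bound affine map, possibly none)
//     dim/symbol operands of bounds[2*i+1]  (upper-bound affine map, possibly none)
// The "bounds" ArrayAttr holds 2*N AffineMapAttr entries, lower/upper interleaved;
// constant bounds use a zero-operand constant map, SSA bounds a one-symbol identity map.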


@ -14,6 +14,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
+#include "src/compiler/dialect/krnl/krnl_helper.hpp"
#include "src/compiler/dialect/krnl/krnl_types.hpp"
namespace mlir {


@ -8,6 +8,7 @@
include "mlir/IR/OpBase.td"
def Krnl_Dialect : Dialect {
let name = "krnl";
let cppNamespace = "";
@ -119,17 +120,14 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
let builders = [
OpBuilder<"Builder *builder, OperationState &result, "
-"ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops, "
-"ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds, "
-"ArrayRef<int> bound_types">
+"KrnlIterateOperandPack operandPack">
];
let extraClassDeclaration = [{
-// In krnl.iterate operation, three types of SSA values are stored:
+// In krnl.iterate operation, operands are stored as such
// - Optimized krnl.loops.
-// - Input krnl.loops.
-// - SSA value based induction variable bound (parametric bound).
+// - Input krnl.loops and their operand bounds. (TODO(Tian) explain better how we store them).
// We record the number of optimized and input loops to separate these three
// group of operands out.
static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; }
@ -143,32 +141,8 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
return num_optimized_loops;
}
-static StringRef getNumInputLoopsAttrName() { return "num_input_loops"; }
+// Get name of the attribute for storing bound represented using affine maps.
+static StringRef getBoundsAttrName() { return "bounds"; }
-int64_t getNumInputLoops() {
-auto num_loops =
-    getAttrOfType<IntegerAttr>(
-        getNumInputLoopsAttrName())
-        .getValue()
-        .getSExtValue();
-return num_loops;
-}
-// Constant bounds are stored here as a list attribute.
-static StringRef getConstantBoundsAttrName() { return "constant_bounds"; }
-// Store type of each bound as three types:
-// - 0 = constant attribute.
-// - 1 = operand type.
-// - 2 = affine maps (TODO).
-static StringRef getBoundTypesAttrName() { return "bound_types"; }
-// Get dynamic attribute name for the i-th lower and upper bound.
-static std::string getBoundAttrName(int64_t i, bool is_ub) {
-std::string bound_type = is_ub ? "_ub" : "_lb";
-std::string bound_idx = std::to_string(i);
-return "__bound_" + bound_idx + bound_type;
-}
}];
let printer = [{ return ::print(p, *this); }];


@ -1,52 +0,0 @@
//===------------------ parser_helper.cpp - MLIR Operations ---------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
#include "parser_helper.hpp"
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
namespace onnf {
mlir::ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
mlir::Type operand_type, mlir::Value*& operand) {
// If operand queue is empty, parse more operands and cache them.
if (_operand_ref_queue.empty()) {
// Parse operand types:
llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> operand_refs;
_parser.parseOperandList(operand_refs);
// Record operands:
for (auto& operand_ref : operand_refs)
_operand_ref_queue.emplace(operand_ref);
}
// If we parsed some operand reference(s), resolve the ref to an operand:
if (!_operand_ref_queue.empty()) {
auto operand_ref = _operand_ref_queue.front();
_operand_ref_queue.pop();
llvm::SmallVector<mlir::Value*, 1> operands;
_parser.resolveOperand(operand_ref, operand_type, operands);
operand = operands.front();
return mlir::success();
} else {
operand = nullptr;
return mlir::failure();
}
}
mlir::ParseResult KrnlDialectOperandParser::ParseOperand(
mlir::Type operand_type, mlir::Value*& operand) {
ParseOptionalOperand(operand_type, operand);
if (operand == nullptr)
return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand.");
return mlir::success();
}
} // namespace onnf


@ -1,46 +0,0 @@
//===------------------ parser_helper.hpp - MLIR Operations ---------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
#pragma once
#include <queue>
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
namespace onnf {
class KrnlDialectOperandParser {
public:
KrnlDialectOperandParser(mlir::OpAsmParser& parser)
: _parser(parser), _builder(parser.getBuilder()){};
// Parse an optional operand.
mlir::ParseResult ParseOptionalOperand(
mlir::Type operand_type, mlir::Value*& operand);
// Parse a required operand.
mlir::ParseResult ParseOperand(
mlir::Type operand_type, mlir::Value*& operand);
// Do we have more operands to parse?
bool has_operand_left() { return !_operand_ref_queue.empty(); }
private:
mlir::OpAsmParser& _parser;
mlir::Builder& _builder;
// A queue storing the parsed SSA id references.
std::queue<mlir::OpAsmParser::OperandType> _operand_ref_queue;
};
} // namespace onnf


@ -16,6 +16,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "src/compiler/dialect/krnl/krnl_helper.hpp"
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
@ -43,13 +44,12 @@ static MemRefType convertTensorToMemRef(TensorType type) {
}
/// Insert an allocation and deallocation for the given MemRefType.
-static Value* insertAllocAndDealloc(
-    MemRefType type, Location loc, PatternRewriter& rewriter,
-    Value *oldMemRef = nullptr) {
+static Value* insertAllocAndDealloc(MemRefType type, Location loc,
+    PatternRewriter& rewriter, Value* oldMemRef = nullptr) {
// Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc;
if (oldMemRef) {
-SmallVector<Value *, 4> allocOperands;
+SmallVector<Value*, 4> allocOperands;
auto memRefShape = type.getShape();
for (int i = 0; i < memRefShape.size(); ++i)
if (memRefShape[i] < 0)
@ -95,7 +95,7 @@ struct ONNXAddOpLowering : public ConversionPattern {
// dimensions with the result at this pre-optimization phase.
// TODO: verify that dimensions match.
// TODO: can the dimension of the result differ after optimizations?
-Value *alloc;
+Value* alloc;
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
else
@ -122,33 +122,22 @@ struct ONNXAddOpLowering : public ConversionPattern {
}
Block& optimizationBlock = optimizedLoopsOp.region().front();
+KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Iterate over the loop nest.
// TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
// to KrnlIterateOp instead.
-SmallVector<Value*, 8> operandBounds;
-SmallVector<int64_t, 8> constBounds;
-SmallVector<int, 16> boundTypes;
for (int i = 0; i < rank; ++i) {
if (memRefShape[i] < 0) {
-// This is a dynamic value, hence use operands.
-// Lower bound
-constBounds.push_back(0);
-boundTypes.push_back(0);
-// Upper bound
-operandBounds.push_back(
+pack.pushConstantBound(0);
+pack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], i).getResult());
-boundTypes.push_back(1);
} else {
-// Lower bound
-constBounds.push_back(0);
-boundTypes.push_back(0);
-// Upper bound
-constBounds.push_back(memRefShape[i]);
-boundTypes.push_back(0);
+pack.pushConstantBound(0);
+pack.pushConstantBound(memRefShape[i]);
}
}
-auto iterateOp = rewriter.create<KrnlIterateOp>(loc, originalLoops,
-    optimizedLoops, operandBounds, constBounds, boundTypes);
+auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block& iterationBlock = iterateOp.bodyRegion().front();
// Now perform the insertions into the body of the
@ -169,14 +158,12 @@ struct ONNXAddOpLowering : public ConversionPattern {
SmallVector<Value*, 4> loopIVs;
for (auto arg : iterationBlock.getArguments())
loopIVs.push_back(arg);
-auto loadedFirstVal =
-    rewriter.create<LoadOp>(loc, operands[0], loopIVs);
-auto loadedSecondVal =
-    rewriter.create<LoadOp>(loc, operands[1], loopIVs);
+auto loadedFirstVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
+auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
// TODO: Choose type of the Add for now use the Float Add.
-auto addOpResult = rewriter.create<AddFOp>(
-    loc, loadedFirstVal, loadedSecondVal);
+auto addOpResult =
+    rewriter.create<AddFOp>(loc, loadedFirstVal, loadedSecondVal);
// Store result in the resulting array.
rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
@ -209,8 +196,8 @@ struct TensorTypeConverter : public TypeConverter {
/// inputs. Once unranked results can be handled gracefully this
/// override needs to be removed in favour of the original MLIR one.]
bool isSignatureLegal(FunctionType funcType) {
-return llvm::all_of(funcType.getInputs(),
-    [this](Type type) { return isLegal(type); });
+return llvm::all_of(
+    funcType.getInputs(), [this](Type type) { return isLegal(type); });
}
};
@ -272,8 +259,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
-if (failed(applyPartialConversion(
-    module, target, patterns)))
+if (failed(applyPartialConversion(module, target, patterns)))
signalPassFailure();
}


@ -17,9 +17,10 @@ class Pass;
std::unique_ptr<Pass> createShapeInferencePass();
-/// Pass for lowering frontend dialects to Krnl IR dialect.
-std::unique_ptr<mlir::Pass> createLowerToKrnlPass();
-// TODO: Add pass for lowering to LLVM IR.
+/// Add pass for lowering to Krnl IR.
+std::unique_ptr<Pass> createLowerToKrnlPass();
+/// Pass for lowering frontend dialects to Krnl IR dialect.
+std::unique_ptr<Pass> createLowerKrnlPass();
} // end namespace mlir


@ -75,7 +75,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
if (auto terminator_op = f.getBody().back().getTerminator()) {
auto results = terminator_op->getOperandTypes();
f.setType(FunctionType::get(f.getType().getInputs(),
std::vector<Type>(results.begin(), results.end()), f.getContext()));
}
}


@ -6,8 +6,7 @@ target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT})
target_link_libraries(onnf-opt compiler ${MLIRLibs})
whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs})
-whole_archive_link_onnf(onnf-opt onnf_lower_frontend)
-whole_archive_link_onnf(onnf-opt onnf_shape_inference)
+whole_archive_link_onnf(onnf-opt onnf_transform onnf_lower_frontend onnf_shape_inference)
# TODO: need to investigate how to whole-archive link compiler pass to onnf-opt.
target_link_libraries(onnf-opt compiler)


@ -0,0 +1,7 @@
add_library(onnf_transform lower_krnl.cpp)
target_include_directories(onnf_transform
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT})
target_link_libraries(onnf_transform ${MLIRLibs})
add_dependencies(onnf_transform gen_krnl_ops)


@ -0,0 +1,161 @@
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/pass/passes.hpp"
using namespace mlir;
namespace {
//===----------------------------------------------------------------------===//
// Krnl to Affine Rewrite Patterns: KrnlIterate operation.
//===----------------------------------------------------------------------===//
struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(
KrnlIterateOp iterateOp, PatternRewriter& rewriter) const override {
auto boundMapAttrs =
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
.getValue();
auto operandItr =
iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops();
SmallVector<AffineForOp, 4> nestedForOps;
for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) {
// Consume input loop operand, currently do not do anything with it.
operandItr++;
// Organize operands into lower/upper bounds in affine.for ready formats.
SmallVector<Value*, 4> lbOperands, ubOperands;
AffineMap lbMap, ubMap;
for (int boundType = 0; boundType < 2; boundType++) {
auto& operands = boundType == 0 ? lbOperands : ubOperands;
auto& map = boundType == 0 ? lbMap : ubMap;
map = boundMapAttrs[boundIdx + boundType]
.cast<AffineMapAttr>()
.getValue();
operands.insert(
operands.end(), operandItr, operandItr + map.getNumInputs());
std::advance(operandItr, map.getNumInputs());
}
nestedForOps.emplace_back(rewriter.create<AffineForOp>(
iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap));
rewriter.setInsertionPoint(nestedForOps.back().getBody(),
nestedForOps.back().getBody()->begin());
}
// Replace induction variable references from those introduced by a
// single krnl.iterate to those introduced by multiple affine.for
// operations.
for (size_t i = 0; i < nestedForOps.size() - 1; i++) {
auto iterateIV = iterateOp.bodyRegion().front().getArgument(0);
auto forIV = nestedForOps[i].getBody()->getArgument(0);
iterateIV->replaceAllUsesWith(forIV);
iterateOp.bodyRegion().front().eraseArgument(0);
}
// Pop krnl.iterate body region block arguments, leave the last one
// for convenience (it'll be taken care of by region inlining).
while (iterateOp.bodyRegion().front().getNumArguments() > 1)
iterateOp.bodyRegion().front().eraseArgument(0);
// Transfer krnl.iterate region to innermost for op.
auto innermostForOp = nestedForOps.back();
innermostForOp.region().getBlocks().clear();
rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(),
innermostForOp.region().end());
rewriter.eraseOp(iterateOp);
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// Krnl to Affine Rewrite Patterns: KrnlTerminator operation.
//===----------------------------------------------------------------------===//
class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
public:
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(
KrnlTerminatorOp op, PatternRewriter& rewriter) const override {
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// Krnl to Affine Rewrite Patterns: KrnlDefineLoops operation.
//===----------------------------------------------------------------------===//
class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
public:
using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(
KrnlDefineLoopsOp op, PatternRewriter& rewriter) const override {
rewriter.eraseOp(op);
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// Krnl to Affine Rewrite Patterns: KrnlOptimizeLoops operation.
//===----------------------------------------------------------------------===//
class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> {
public:
using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(
KrnlOptimizeLoopsOp op, PatternRewriter& rewriter) const override {
rewriter.eraseOp(op);
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// KrnlToAffineLoweringPass
//===----------------------------------------------------------------------===//
/// This is a partial lowering to affine loops of the krnl dialect operations.
/// At this stage the dialect will contain standard operations as well like
/// add and multiply, this pass will leave these operations intact.
namespace {
struct KrnlToAffineLoweringPass
: public FunctionPass<KrnlToAffineLoweringPass> {
void runOnFunction() final;
};
} // end anonymous namespace.
void KrnlToAffineLoweringPass::runOnFunction() {
auto function = getFunction();
ConversionTarget target(getContext());
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
// We expect IR to be free of Krnl Dialect Ops.
target.addIllegalDialect<KrnlOpsDialect>();
OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext());
if (failed(applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
}
} // namespace
std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
return std::make_unique<KrnlToAffineLoweringPass>();
}
static PassRegistration<KrnlToAffineLoweringPass> pass(
"lower-krnl", "Lower Krnl dialect.");


@ -28,6 +28,7 @@
#include <boost/program_options.hpp>
+#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
@ -38,6 +39,8 @@
#include "src/compiler/pass/passes.hpp"
#include "mlir/Analysis/Verifier.h"
+#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/MLIRContext.h"
@ -125,7 +128,20 @@ int main(int ac, char* av[]) {
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createLowerToKrnlPass());
+pm.addPass(mlir::createLowerKrnlPass());
+pm.addPass(mlir::createLowerAffinePass());
+pm.addPass(mlir::createLowerToCFGPass());
+pm.addPass(mlir::createLowerToLLVMPass());
+pm.addPass(mlir::createCanonicalizerPass());
pm.run(*module);
+// Write LLVM bitcode to disk.
+std::error_code EC;
+llvm::raw_fd_ostream moduleBitcodeStream(
+    "model.bc", EC, llvm::sys::fs::F_None);
+llvm::WriteBitcodeToFile(
+    *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
+moduleBitcodeStream.flush();
return 0;
}

test/mlir/krnl/ops.mlir (new file, 75 lines)

@ -0,0 +1,75 @@
// RUN: onnf-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s
// RUN: onnf-opt %s | FileCheck %s
// GENERIC-DAG: #{{.*}} = () -> (0)
// GENERIC-DAG: #{{.*}} = () -> (10)
// GENERIC-DAG: #{{.*}} = () -> (1)
// GENERIC-DAG: #{{.*}} = () -> (11)
// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 - d1)
// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 + d1)
func @simple_iterate(%N : index) {
%ii, %ij, %ik = krnl.define_loops 3
%oi, %oj, %ok = krnl.optimize_loops {
krnl.return_loops %ii, %ij, %ik
} : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
// GENERIC-NEXT: "krnl.terminate"() : () -> ()
// GENERIC-NEXT: bounds = [#{{.*}}, #{{.*}}, #{{.*}}, #{{.*}}]
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 1 to 11) {
krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 1 to 11) {
}
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) {
krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 0 to 10) {
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}) ( {
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10) {
krnl.iterate(%ok) with (%ik -> %k = 0 to 10) {
}
}
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to %{{.*}}, %{{.*}} -> %{{.*}} = 0 to 10) {
krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to %N, %ij -> %j = 0 to 10) {
}
return
}
func @affine_map_bound(%N : index) {
%ii, %ij, %ik = krnl.define_loops 3
%oi, %oj, %ok = krnl.optimize_loops {
krnl.return_loops %ii, %ij, %ik
} : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) {
krnl.iterate(%oi, %oj) with (%ii -> %i = ()->(0)() to ()->(10)(), %ij -> %j = 0 to 10) {
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = #{{.*}}(%{{.*}}, %{{.*}}) to #{{.*}}(%{{.*}}, %{{.*}})) {
krnl.iterate(%ok) with (%ik -> %k = (d0, d1)->(d0 - d1)(%i, %j) to (d0, d1)->(d0 + d1)(%i, %j)) {
}
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = max #map{{.*}}(%{{.*}}, %{{.*}}) to min #map{{.*}}(%{{.*}}, %{{.*}})[%{{.*}}]) {
krnl.iterate(%ok) with (%ik -> %k = max (d0, d1)->(d0 - d1, 0)(%i, %j) to min (d0, d1)[s0]->(d0 + d1, s0)(%i, %j)[%N]) {
}
}
return
}