diff --git a/CMakeLists.txt b/CMakeLists.txt index 38f64ae..440ac55 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,6 +18,8 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) #TODO(eventually enable the following) #set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +include(MLIR.cmake) + add_subdirectory(third_party/onnx) add_subdirectory(third_party/benchmark) add_subdirectory(third_party/pybind11) @@ -41,7 +43,6 @@ if(Boost_FOUND) include_directories(${Boost_INCLUDE_DIRS}) endif() -include(MLIR.cmake) add_subdirectory(src/builder) add_subdirectory(src/compiler) add_subdirectory(src) diff --git a/MLIR.cmake b/MLIR.cmake index af9c75e..d8d9da7 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -5,7 +5,7 @@ if(DEFINED ENV{LLVM_SRC}) message(STATUS "LLVM_SRC " ${LLVM_SRC}) else() message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: " - ${LLVM_SRC}) + ${LLVM_SRC}) endif() else() message(FATAL_ERROR "env variable LLVM_SRC not set") @@ -18,7 +18,7 @@ if(DEFINED ENV{LLVM_BUILD}) message(STATUS "LLVM_BUILD " ${LLVM_BUILD}) else() message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: " - ${LLVM_BUILD}) + ${LLVM_BUILD}) endif() else() message(FATAL_ERROR "env variable LLVM_BUILD not set") @@ -39,9 +39,9 @@ set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir) set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir) set( - MLIR_INCLUDE_PATHS - ${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH} - ) + MLIR_INCLUDE_PATHS + ${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH} +) include_directories(${MLIR_INCLUDE_PATHS}) # Threading libraries required due to parallel pass execution. @@ -49,9 +49,9 @@ find_package(Threads REQUIRED) function(find_mlir_lib lib) find_library(${lib} - NAMES ${lib} - PATHS ${LLVM_PROJECT_LIB} - NO_DEFAULT_PATH) + NAMES ${lib} + PATHS ${LLVM_PROJECT_LIB} + NO_DEFAULT_PATH) endfunction(find_mlir_lib) find_mlir_lib(MLIRAffineOps) @@ -70,6 +70,10 @@ find_mlir_lib(MLIRTransforms) find_mlir_lib(MLIRTransformUtils) find_mlir_lib(MLIRSupport) find_mlir_lib(MLIROptMain) +find_mlir_lib(MLIRTargetLLVMIRModuleTranslation) +find_mlir_lib(MLIRTargetLLVMIR) +find_mlir_lib(MLIRTransformUtils) +find_mlir_lib(MLIRTranslation) find_mlir_lib(MLIRVectorOps) find_mlir_lib(LLVMCore) @@ -80,46 +84,52 @@ find_mlir_lib(LLVMRemarks) find_mlir_lib(LLVMIRReader) find_mlir_lib(LLVMTransformUtils) find_mlir_lib(LLVMBitstreamReader) +find_mlir_lib(LLVMAnalysis) +find_mlir_lib(LLVMBitWriter) +find_mlir_lib(LLVMBitReader) +find_mlir_lib(LLVMMC) +find_mlir_lib(LLVMMCParser) +find_mlir_lib(LLVMObject) +find_mlir_lib(LLVMProfileData) +find_mlir_lib(LLVMDemangle) + set(MLIRLibsOnce + LLVMAnalysis + LLVMAsmParser + LLVMBinaryFormat + LLVMBitReader + LLVMBitstreamReader + LLVMBitWriter + LLVMCore + LLVMIRReader + LLVMMC + LLVMMCParser + LLVMObject + LLVMRemarks + LLVMSupport + LLVMTransformUtils + LLVMProfileData + LLVMDemangle MLIRAffineOps MLIRAffineToStandard MLIRAnalysis MLIRExecutionEngine MLIRIR MLIRLLVMIR + MLIRLoopOps MLIRLoopToStandard + MLIROptMain MLIRParser MLIRPass MLIRStandardOps MLIRStandardToLLVM + MLIRSupport MLIRTargetLLVMIR - MLIRTransforms - MLIRAffineOps - MLIRAffineToStandard - MLIRAnalysis - MLIRExecutionEngine - MLIRIR - MLIRLLVMIR - MLIRLoopToStandard - MLIRParser - MLIRPass - MLIRStandardOps - MLIRStandardToLLVM - MLIRTargetLLVMIR + MLIRTargetLLVMIRModuleTranslation MLIRTransforms MLIRTransformUtils - MLIRLoopOps - MLIRSupport - 
MLIROptMain - LLVMCore - LLVMSupport - LLVMAsmParser - LLVMIRReader - LLVMTransformUtils - LLVMBinaryFormat - LLVMRemarks - LLVMBitstreamReader) + MLIRTranslation) set(MLIRLibs ${MLIRLibsOnce} @@ -142,7 +152,7 @@ function(whole_archive_link target lib_dir) set(link_flags "${link_flags} -L${lib_dir} ") foreach(LIB ${ARGN}) string(CONCAT link_flags ${link_flags} - "-Wl,-force_load ${lib_dir}/lib${LIB}.a ") + "-Wl,-force_load ${lib_dir}/lib${LIB}.a ") endforeach(LIB) elseif(MSVC) foreach(LIB ${ARGN}) @@ -170,20 +180,20 @@ function(whole_archive_link_onnf target) endfunction(whole_archive_link_onnf) set(LLVM_CMAKE_DIR - "${LLVM_BUILD}/lib/cmake/llvm" - CACHE PATH "Path to LLVM cmake modules") + "${LLVM_BUILD}/lib/cmake/llvm" + CACHE PATH "Path to LLVM cmake modules") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") include(AddLLVM) include(TableGen) function(onnf_tablegen ofn) tablegen(MLIR - ${ARGV} - "-I${MLIR_SRC_INCLUDE_PATH}" - "-I${MLIR_BIN_INCLUDE_PATH}") + ${ARGV} + "-I${MLIR_SRC_INCLUDE_PATH}" + "-I${MLIR_BIN_INCLUDE_PATH}") set(TABLEGEN_OUTPUT - ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn} - PARENT_SCOPE) + ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn} + PARENT_SCOPE) endfunction() # Import the pre-built mlir TableGen as an imported exetuable. It is required by @@ -191,5 +201,5 @@ endfunction() # table gen utility itself can be detected and cause re-compilation of .td file. add_executable(mlir-tblgen IMPORTED) set_property(TARGET mlir-tblgen - PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen) + PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen) set(MLIR_TABLEGEN_EXE mlir-tblgen) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c36eb0b..74fa34a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -2,8 +2,9 @@ add_executable(onnf main.cpp) target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES}) whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs}) +whole_archive_link_onnf(onnf onnf_transform) target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR}) target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR}) -install(TARGETS onnf DESTINATION bin) \ No newline at end of file +install(TARGETS onnf DESTINATION bin) diff --git a/src/compiler/CMakeLists.txt b/src/compiler/CMakeLists.txt index 480e2e0..46a3ce6 100644 --- a/src/compiler/CMakeLists.txt +++ b/src/compiler/CMakeLists.txt @@ -6,8 +6,8 @@ add_library( dialect/krnl/krnl_types.hpp dialect/onnx/onnx_ops.cpp dialect/onnx/onnx_ops.hpp - dialect/krnl/parser_helper.cpp - dialect/krnl/parser_helper.hpp + dialect/krnl/krnl_helper.cpp + dialect/krnl/krnl_helper.hpp pass/shape_inference_pass.cpp pass/shape_inference_interface.hpp dialect/onnx/onnxop.inc @@ -82,4 +82,5 @@ target_include_directories(onnf_lower_frontend target_link_libraries(onnf_lower_frontend ${MLIRLibs}) add_dependencies(onnf_lower_frontend gen_krnl_ops) +add_subdirectory(transform) add_subdirectory(tool) diff --git a/src/compiler/dialect/krnl/krnl_helper.cpp b/src/compiler/dialect/krnl/krnl_helper.cpp new file mode 100644 index 0000000..b1e6de1 --- /dev/null +++ b/src/compiler/dialect/krnl/krnl_helper.cpp @@ -0,0 +1,139 @@ +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/IR/AffineExpr.h" + +#include "src/compiler/dialect/krnl/krnl_ops.hpp" + +#include "krnl_helper.hpp" + +namespace onnf { + +using namespace mlir; + +ParseResult KrnlDialectOperandParser::ParseOptionalOperand( + const Type& operandType, Value*& operand) { + // If operand queue is empty, parse more operands and cache them. 
+ if (_operandRefQueue.empty()) { + // Parse operand types: + llvm::SmallVector operand_refs; + _parser.parseOperandList(operand_refs); + + // Record operands: + for (auto& operand_ref : operand_refs) + _operandRefQueue.emplace(operand_ref); + } + + // If we parsed some operand reference(s), resolve the ref to an operand: + if (!_operandRefQueue.empty()) { + auto operand_ref = _operandRefQueue.front(); + _operandRefQueue.pop(); + + llvm::SmallVector operands; + _parser.resolveOperand(operand_ref, operandType, operands); + operand = operands.front(); + return success(); + } else { + operand = nullptr; + return failure(); + } +} + +ParseResult KrnlDialectOperandParser::ParseOptionalOperand( + const Type& operandType, llvm::SmallVectorImpl& operandList) { + Value* operand = nullptr; + if (ParseOptionalOperand(operandType, operand)) + return failure(); + + operandList.emplace_back(operand); + return success(); +} + +ParseResult KrnlDialectOperandParser::ParseOperand( + const Type& operandType, Value*& operand) { + if (ParseOptionalOperand(operandType, operand)) + return _parser.emitError( + _parser.getCurrentLocation(), "Expecting an operand."); + return success(); +} + +ParseResult KrnlDialectOperandParser::ParseOperand( + const Type& operandType, llvm::SmallVectorImpl& operandList) { + if (ParseOptionalOperand(operandType, operandList)) + return _parser.emitError( + _parser.getCurrentLocation(), "Expecting an operand."); + + return success(); +} + +void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims, + unsigned numSymbols, OpAsmPrinter& p) { + p << '('; + p.printOperands(begin, begin + numDims); + p << ')'; + + if (numSymbols) { + p << '['; + p.printOperands(begin + numDims, begin + numDims + numSymbols); + p << ']'; + } + + begin = std::next(begin, numDims + numSymbols); +} + +void printBound(AffineMapAttr boundMap, + Operation::operand_iterator& boundOperandsBeg, const char* prefix, + OpAsmPrinter& p) { + AffineMap map = boundMap.getValue(); + + // Check if this bound should be printed using custom assembly form. + // The decision to restrict printing custom assembly form to trivial cases + // comes from the will to roundtrip MLIR binary -> text -> binary in a + // lossless way. + // Therefore, custom assembly form parsing and printing is only supported for + // zero-operand constant maps and single symbol operand identity maps. + if (map.getNumResults() == 1) { + AffineExpr expr = map.getResult(0); + + // Print constant bound. + if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { + if (auto constExpr = expr.dyn_cast()) { + p << constExpr.getValue(); + return; + } + } + + // Print bound that consists of a single SSA symbol if the map is over a + // single symbol. + if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { + if (auto symExpr = expr.dyn_cast()) { + p.printOperand(*(boundOperandsBeg++)); + return; + } + } + } else { + // Map has multiple results. Print 'min' or 'max' prefix. + p << prefix << ' '; + } + + // Print the map and its operands. 
+ p << boundMap; + printDimAndSymbolList( + boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p); +} +} // namespace onnf + +namespace mlir { +void KrnlIterateOperandPack::pushConstantBound(int64_t bound) { + if (boundMaps.size() % 2 == 0) + _operands.emplace_back(inputLoops[boundMaps.size() / 2]); + AffineMap map = builder.getConstantAffineMap(bound); + boundMaps.emplace_back(AffineMapAttr::get(map)); +} + +void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) { + if (boundMaps.size() % 2 == 0) + _operands.emplace_back(inputLoops[boundMaps.size() / 2]); + AffineMap map = builder.getSymbolIdentityMap(); + boundMaps.emplace_back(AffineMapAttr::get(map)); + _operands.emplace_back(operand); +} +} // namespace mlir diff --git a/src/compiler/dialect/krnl/krnl_helper.hpp b/src/compiler/dialect/krnl/krnl_helper.hpp new file mode 100644 index 0000000..8573af0 --- /dev/null +++ b/src/compiler/dialect/krnl/krnl_helper.hpp @@ -0,0 +1,102 @@ +#pragma once + +#include + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" + +namespace onnf { + +class KrnlDialectOperandParser { + public: + explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser) + : _parser(parser), _builder(parser.getBuilder()){}; + + // Parse an optional operand. + mlir::ParseResult ParseOptionalOperand( + const mlir::Type& operandType, mlir::Value*& operand); + + // Parse an optional operand and push it to an operand list. + mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType, + llvm::SmallVectorImpl& operandList); + + // Parse a required operand. + mlir::ParseResult ParseOperand( + const mlir::Type& operandType, mlir::Value*& operand); + + // Parse a required operand and push it to an operand list. + mlir::ParseResult ParseOperand(const mlir::Type& operandType, + llvm::SmallVectorImpl& operandList); + + // Do we have more operands to parse? + bool hasOperandLeft() { return !_operandRefQueue.empty(); } + + private: + mlir::OpAsmParser& _parser; + + mlir::Builder& _builder; + + // A queue storing the parsed SSA id references. + std::queue _operandRefQueue; +}; + +// Adapted from: +// https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197 +// Main difference is that it advances the iterator `begin` as it consumes +// dimension and symbol operands. +void printDimAndSymbolList(mlir::Operation::operand_iterator& begin, + unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p); + +// Adapted from: +// https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272 +// Main difference is that it advances the iterator `boundOperandsBeg` as it +// prints bound. 
+void printBound(mlir::AffineMapAttr boundMap, + mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix, + mlir::OpAsmPrinter& p); +} // namespace onnf + +namespace mlir { + +struct KrnlIterateOperandPack { + KrnlIterateOperandPack(mlir::Builder& builder, + llvm::ArrayRef inputLoops, + llvm::ArrayRef optimizedLoops) + : builder(builder), + inputLoops(inputLoops), + optimizedLoops(optimizedLoops) { + _operands.insert( + _operands.end(), optimizedLoops.begin(), optimizedLoops.end()); + } + + void pushConstantBound(int64_t bound); + + void pushOperandBound(mlir::Value* operand); + + llvm::SmallVector getOperands() const { return _operands; } + + mlir::ArrayAttr getAttributes() const { + return builder.getArrayAttr(boundMaps); + } + + size_t getNumOptimizedLoops() const { return optimizedLoops.size(); } + + size_t getNumInputLoops() const { return inputLoops.size(); } + + private: + int _boundIdx = 0; + + llvm::SmallVector _operands; + + llvm::SmallVector boundMaps; + + llvm::ArrayRef inputLoops, optimizedLoops; + + mlir::Builder& builder; +}; + +} // namespace mlir diff --git a/src/compiler/dialect/krnl/krnl_ops.cpp b/src/compiler/dialect/krnl/krnl_ops.cpp index 5dbd16d..ee71925 100644 --- a/src/compiler/dialect/krnl/krnl_ops.cpp +++ b/src/compiler/dialect/krnl/krnl_ops.cpp @@ -9,10 +9,10 @@ #include #include -#include "src/compiler/dialect/krnl/parser_helper.hpp" - #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" @@ -24,6 +24,8 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "src/compiler/dialect/krnl/krnl_helper.hpp" + #include "krnl_ops.hpp" using namespace mlir; @@ -52,26 +54,25 @@ void KrnlDefineLoopsOp::build( } void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) { - auto num_loop_attr = op.getAttrOfType(op.getNumLoopsAttrName()); - p << "krnl.define_loops " << num_loop_attr.getValue().getSExtValue(); + auto numLoopAttr = + op.getAttrOfType(KrnlDefineLoopsOp::getNumLoopsAttrName()); + p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue(); } ParseResult parseKrnlDefineLoopsOp( OpAsmParser& parser, OperationState& result) { // Parse the attribute indicating number of loops defined. 
- IntegerAttr num_loops; + IntegerAttr numLoops; auto& builder = parser.getBuilder(); - auto int32_type = builder.getIntegerType(64); - if (parser.parseAttribute(num_loops, int32_type, + auto intType = builder.getIntegerType(64); + if (parser.parseAttribute(numLoops, intType, KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes)) return failure(); - auto loop_types = llvm::SmallVector( - num_loops.getValue().getSExtValue(), LoopType::get(builder.getContext())); - if (parser.addTypesToList(loop_types, result.types)) + auto loopTypes = llvm::SmallVector( + numLoops.getValue().getSExtValue(), LoopType::get(builder.getContext())); + if (parser.addTypesToList(loopTypes, result.types)) return failure(); - - return success(); } //===----------------------------------------------------------------------===// @@ -142,39 +143,14 @@ ParseResult parseKrnlOptimizeLoopsOp( * %i0 = 10 to N : %i1 = M to 20 */ void KrnlIterateOp::build(Builder* builder, OperationState& result, - ArrayRef input_loops, ArrayRef optimized_loops, - ArrayRef operand_bounds, ArrayRef const_bounds, - ArrayRef bound_types) { + KrnlIterateOperandPack operandPack) { // Record optimized loops and the number of such loops. - result.addOperands(optimized_loops); + result.addOperands(operandPack.getOperands()); + result.addAttribute( + KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes()); + result.addAttribute(getNumOptimizedLoopsAttrName(), - builder->getI64IntegerAttr(optimized_loops.size())); - - // Record input loops and the number of such loops. - result.addOperands(input_loops); - result.addAttribute(getNumInputLoopsAttrName(), - builder->getI64IntegerAttr(input_loops.size())); - - // Record bound either as attribute or from operand list. - auto next_operand_bound = operand_bounds.begin(); - auto next_const_bound = const_bounds.begin(); - for (size_t i = 0; i < bound_types.size(); i++) { - auto bound_type = bound_types[i]; - if (bound_type == 0) { - // Constant bound. - result.addAttribute(getBoundAttrName(i / 2, i % 2), - builder->getI64IntegerAttr(*next_const_bound)); - next_const_bound = std::next(next_const_bound); - } else { - // Operand bound. - result.addOperands(*next_operand_bound); - next_operand_bound = std::next(next_operand_bound); - } - } - - // Record bound types as attribute: - result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(), - builder->getI32ArrayAttr(bound_types)); + builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops())); // Create a region and a block for the body. 
The arguments of the region are // the loop induction variables; there can be multiple induction variables @@ -182,7 +158,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result, Region* bodyRegion = result.addRegion(); auto* body = new Block(); auto body_args = llvm::SmallVector( - input_loops.size(), IndexType::get(builder->getContext())); + operandPack.getNumInputLoops(), IndexType::get(builder->getContext())); body->addArguments(body_args); bodyRegion->push_back(body); @@ -192,57 +168,31 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result, void print(OpAsmPrinter& p, KrnlIterateOp& op) { p << "krnl.iterate("; // Print optimized loops: - auto num_optimized_loops = op.getNumOptimizedLoops(); - p.printOperands(op.operand_begin(), op.operand_begin() + num_optimized_loops); + auto numOptimizedLoops = op.getNumOptimizedLoops(); + p.printOperands(op.operand_begin(), op.operand_begin() + numOptimizedLoops); p << ") with ("; - // Set up iterator to input loops: - auto num_input_loops = op.getNumInputLoops(); - auto input_loop_begin = op.operand_begin() + num_optimized_loops; + auto inductionVars = op.bodyRegion().begin()->getArguments(); + auto boundItr = + op.getAttrOfType(KrnlIterateOp::getBoundsAttrName()) + .getValue() + .begin(); + auto operandItr = op.operand_begin() + numOptimizedLoops; - // Set up iterators to operand bounds. - auto next_operand_bound = input_loop_begin + num_input_loops; - - // Function to print a lower or upper bound. - auto print_bound = [&](ArrayRef bound_types, size_t idx) { - IntegerAttr type = bound_types[idx].dyn_cast(); - if (type.getValue().getSExtValue() == 0) { - // Bound is an integer attribute. - auto bound_idx = idx / 2; - auto is_ub = idx % 2; - IntegerAttr bound = op.getAttrOfType( - KrnlIterateOp::getBoundAttrName(bound_idx, is_ub)); - p << bound.getValue().getSExtValue(); - } else { - // Bound is an operand. - p.printOperand(*next_operand_bound); - next_operand_bound = std::next(next_operand_bound); - } - }; - - auto induction_variables = op.bodyRegion().front().getArguments(); - ArrayRef bound_types = - op.getAttrOfType(KrnlIterateOp::getBoundTypesAttrName()) - .getValue(); - - // Print input loop operands, induction variables and their ranges. - for (size_t i = 0; i < num_input_loops; i++) { - if (i != 0) - p << ", "; - - p.printOperand(*std::next(input_loop_begin, i)); + std::string delimiter; + for (auto& var : inductionVars) { + p << delimiter; + p.printOperand(*operandItr++); p << " -> "; - - // Print induction variable block argument. - p.printOperand(induction_variables[i]); + p.printOperand(var); p << " = "; - - print_bound(bound_types, 2 * i); // Print lower bound. + onnf::printBound((*boundItr++).cast(), operandItr, "max", p); p << " to "; - print_bound(bound_types, 2 * i + 1); // Print upper bound. 
+ onnf::printBound((*boundItr++).cast(), operandItr, "min", p); + delimiter = ", "; } - p << ")"; + p << ")"; p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); } @@ -250,80 +200,109 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) { ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) { auto builder = parser.getBuilder(); auto context = builder.getContext(); - onnf::KrnlDialectOperandParser operand_parser(parser); + onnf::KrnlDialectOperandParser operandParser(parser); // Parse optimized loops: - SmallVector num_optimized_loops; + SmallVector optimizedLoopRefs; if (parser.parseOperandList( - num_optimized_loops, OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(num_optimized_loops, + optimizedLoopRefs, OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(optimizedLoopRefs, LoopType::get(result.getContext()), result.operands)) return failure(); // Record how many optimized loops did we parse. result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(), - builder.getI64IntegerAttr(num_optimized_loops.size())); + builder.getI64IntegerAttr(optimizedLoopRefs.size())); // Parse input loops and their lower and upper bounds. - SmallVector in_loop_refs, induction_var_refs; - SmallVector in_loop_operands, operand_bounds; - SmallVector bound_types; - SmallVector const_bounds; + SmallVector inductionVarRefs; + SmallVector boundMaps; if (parser.parseKeyword("with") || parser.parseLParen()) return failure(); // A function to parse a lower or upper bound. - auto parse_bound = [&result, &builder, &operand_parser, &parser, &bound_types, - &operand_bounds, &const_bounds]( - bool is_ub, size_t bound_pair_count) -> ParseResult { + auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps]( + bool isUpper) -> ParseResult { + // 'min' / 'max' prefixes are generally syntactic sugar, but are required if + // the map has multiple results. + bool failedToParsedMinMax = + failed(parser.parseOptionalKeyword(isUpper ? "min" : "max")); + // Try parse an SSA operand. - Value* bound; - operand_parser.ParseOptionalOperand(builder.getIndexType(), bound); + if (succeeded(operandParser.ParseOptionalOperand( + builder.getIndexType(), result.operands))) { + AffineMap map = builder.getSymbolIdentityMap(); + boundMaps.emplace_back(AffineMapAttr::get(map)); + return success(); + } - if (bound != nullptr) { - // Parsed an SSA id as bound. - operand_bounds.emplace_back(bound); - // Record bound_type as an operand type. - bound_types.emplace_back(builder.getI32IntegerAttr(0)); - } else { - // Bound is not an SSA id, then it must be an integer. - // Parse an integer constant attribute. - IntegerAttr boundAttr; - if (parser.parseAttribute(boundAttr, builder.getIndexType(), - KrnlIterateOp::getBoundAttrName(bound_pair_count, is_ub), - result.attributes)) + // Bound is not an SSA id, then it must be an integer. + // Parse an integer constant attribute. + // Get the attribute location. 
+      llvm::SMLoc attrLoc = parser.getCurrentLocation();
+      Attribute boundAttr;
+      llvm::SmallVector<NamedAttribute, 2> tempBoundAttrContainer;
+      if (parser.parseAttribute(
+              boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer))
+        return failure();
+
+      if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
+        unsigned currentNumOperands = result.operands.size();
+        unsigned numDims = 0;
+        if (parseDimAndSymbolList(parser, result.operands, numDims))
           return failure();
-      const_bounds.emplace_back(
-          builder.getIntegerAttr(builder.getIndexType(), boundAttr.getValue()));
-      // Record that the bound_type is a constant integer attribute.
-      bound_types.emplace_back(builder.getI32IntegerAttr(1));
+        auto map = affineMapAttr.getValue();
+        if (map.getNumDims() != numDims)
+          return parser.emitError(parser.getNameLoc(),
+              "dim operand count and integer set dim count must match");
+
+        unsigned numDimAndSymbolOperands =
+            result.operands.size() - currentNumOperands;
+        if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
+          return parser.emitError(parser.getNameLoc(),
+              "symbol operand count and integer set symbol count must match");
+
+        // If the map has multiple results, make sure that we parsed the min/max
+        // prefix.
+        if (map.getNumResults() > 1 && failedToParsedMinMax) {
+          if (isUpper)
+            return parser.emitError(attrLoc,
+                "upper loop bound affine map with multiple "
+                "results requires 'min' prefix");
+          return parser.emitError(attrLoc,
+              "lower loop bound affine map with "
+              "multiple results requires 'max' prefix");
+        }
+        boundMaps.emplace_back(AffineMapAttr::get(map));
+        return success();
+      }
+
+      if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
+        AffineMap map =
+            builder.getConstantAffineMap(integerAttr.getValue().getSExtValue());
+        boundMaps.emplace_back(AffineMapAttr::get(map));
+      }
+  };
-  bool keep_parsing;  // Do we keep parsing loops/bounds?
-  size_t bound_pair_count = 0;  // Record the number of bound pairs parsed.
+  bool keepParsing;  // Do we keep parsing loops/bounds?
   do {
     // Parse an input loop operand;
-    Value* in_loop_operand;
-    operand_parser.ParseOperand(LoopType::get(context), in_loop_operand);
-    in_loop_operands.emplace_back(in_loop_operand);
-
+    operandParser.ParseOperand(LoopType::get(context), result.operands);
     parser.parseArrow();
     // Parse induction variable.
-    OpAsmParser::OperandType induction_var;
-    if (parser.parseRegionArgument(induction_var) || parser.parseEqual())
+    OpAsmParser::OperandType inductionVar;
+    if (parser.parseRegionArgument(inductionVar) || parser.parseEqual())
       return failure();
-    induction_var_refs.emplace_back(induction_var);
+    inductionVarRefs.emplace_back(inductionVar);
     // Parse bound par (min to max).
-    if (parse_bound(false, bound_pair_count) || parser.parseKeyword("to") ||
-        parse_bound(true, bound_pair_count))
+    if (parseBound(/*isUpper=*/false) || parser.parseKeyword("to") ||
+        parseBound(/*isUpper=*/true))
       return failure();
-    bound_pair_count++;
     // We may fail to parse a comma if an operand bound is followed by
     // a comma and the next input loop operand, in which case
     // the entire "{operand bound}, {input_loop_operand}" sequence will
@@ -331,33 +310,19 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
     parser.parseOptionalComma();
     // If we don't see a RParen token, we keep parsing.
-    keep_parsing = failed(parser.parseOptionalRParen());
-  } while (keep_parsing);
+    keepParsing = failed(parser.parseOptionalRParen());
+  } while (keepParsing);
   // At this point, there shouldn't be any operands left to parse.
- if (operand_parser.has_operand_left()) + if (operandParser.hasOperandLeft()) return parser.emitError(parser.getCurrentLocation()); + result.addAttribute( + KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps)); - // Record how many input loops did we parse. - result.addOperands(in_loop_operands); - result.addAttribute(KrnlIterateOp::getNumInputLoopsAttrName(), - builder.getI64IntegerAttr(in_loop_operands.size())); - - // Add operand bounds to the list of operands of current operation. - result.addOperands(operand_bounds); - - // A list of 2N elements where the (2n) and (2n+1) th element specifies - // whether the lower and upper bound of the n'th induction variable is stored - // as an operand or as an attribute. N being the number of input loops - // specified in this krnl.iterate operation. - result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(), - builder.getArrayAttr(bound_types)); - - // Parse the schedule body region. Region* region = result.addRegion(); - SmallVector induction_var_types( - induction_var_refs.size(), builder.getIndexType()); - if (parser.parseRegion(*region, induction_var_refs, induction_var_types)) + SmallVector inductionVarTypes( + inductionVarRefs.size(), builder.getIndexType()); + if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes)) return failure(); // Ensure iterate region is closed off with krnl.terminate. diff --git a/src/compiler/dialect/krnl/krnl_ops.hpp b/src/compiler/dialect/krnl/krnl_ops.hpp index 8e9ae83..89d4587 100644 --- a/src/compiler/dialect/krnl/krnl_ops.hpp +++ b/src/compiler/dialect/krnl/krnl_ops.hpp @@ -14,6 +14,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" +#include "src/compiler/dialect/krnl/krnl_helper.hpp" #include "src/compiler/dialect/krnl/krnl_types.hpp" namespace mlir { diff --git a/src/compiler/dialect/krnl/krnl_ops.td b/src/compiler/dialect/krnl/krnl_ops.td index e6a4830..76a38d9 100644 --- a/src/compiler/dialect/krnl/krnl_ops.td +++ b/src/compiler/dialect/krnl/krnl_ops.td @@ -8,6 +8,7 @@ include "mlir/IR/OpBase.td" + def Krnl_Dialect : Dialect { let name = "krnl"; let cppNamespace = ""; @@ -119,17 +120,14 @@ def KrnlIterateOp : Op { let builders = [ OpBuilder<"Builder *builder, OperationState &result, " - "ArrayRef input_loops, ArrayRef optimized_loops, " - "ArrayRef operand_bounds, ArrayRef const_bounds, " - "ArrayRef bound_types"> + "KrnlIterateOperandPack operandPack"> ]; let extraClassDeclaration = [{ - - // In krnl.iterate operation, three types of SSA values are stored: + // In krnl.iterate operation, operands are stored as such // - Optimized krnl.loops. - // - Input krnl.loops. - // - SSA value based induction variable bound (parametric bound). + // - Input krnl.loops and their operand bounds. (TODO(Tian) explain better how we store them). + // We record the number of optimized and input loops to separate these three // group of operands out. static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; } @@ -143,32 +141,8 @@ def KrnlIterateOp : Op { return num_optimized_loops; } - static StringRef getNumInputLoopsAttrName() { return "num_input_loops"; } - - int64_t getNumInputLoops() { - auto num_loops = - getAttrOfType( - getNumInputLoopsAttrName()) - .getValue() - .getSExtValue(); - return num_loops; - } - - // Constant bounds are stored here as a list attribute. - static StringRef getConstantBoundsAttrName() { return "constant_bounds"; } - - // Store type of each bound as three types: - // - 0 = constant attribute. 
- // - 1 = operand type. - // - 2 = affine maps (TODO). - static StringRef getBoundTypesAttrName() { return "bound_types"; } - - // Get dynamic attribute name for the i-th lower and upper bound. - static std::string getBoundAttrName(int64_t i, bool is_ub) { - std::string bound_type = is_ub ? "_ub" : "_lb"; - std::string bound_idx = std::to_string(i); - return "__bound_" + bound_idx + bound_type; - } + // Get name of the attribute for storing bound represented using affine maps. + static StringRef getBoundsAttrName() { return "bounds"; } }]; let printer = [{ return ::print(p, *this); }]; diff --git a/src/compiler/dialect/krnl/parser_helper.cpp b/src/compiler/dialect/krnl/parser_helper.cpp deleted file mode 100644 index 814b0fa..0000000 --- a/src/compiler/dialect/krnl/parser_helper.cpp +++ /dev/null @@ -1,52 +0,0 @@ -//===------------------ parser_helper.cpp - MLIR Operations ---------------===// -// -// Copyright 2019 The IBM Research Authors. -// -// ============================================================================= -// -//===----------------------------------------------------------------------===// - -#include "parser_helper.hpp" - -#include "src/compiler/dialect/krnl/krnl_ops.hpp" - -namespace onnf { - -mlir::ParseResult KrnlDialectOperandParser::ParseOptionalOperand( - mlir::Type operand_type, mlir::Value*& operand) { - // If operand queue is empty, parse more operands and cache them. - if (_operand_ref_queue.empty()) { - // Parse operand types: - llvm::SmallVector operand_refs; - _parser.parseOperandList(operand_refs); - - // Record operands: - for (auto& operand_ref : operand_refs) - _operand_ref_queue.emplace(operand_ref); - } - - // If we parsed some operand reference(s), resolve the ref to an operand: - if (!_operand_ref_queue.empty()) { - auto operand_ref = _operand_ref_queue.front(); - _operand_ref_queue.pop(); - - llvm::SmallVector operands; - _parser.resolveOperand(operand_ref, operand_type, operands); - operand = operands.front(); - return mlir::success(); - } else { - operand = nullptr; - return mlir::failure(); - } -} - -mlir::ParseResult KrnlDialectOperandParser::ParseOperand( - mlir::Type operand_type, mlir::Value*& operand) { - ParseOptionalOperand(operand_type, operand); - if (operand == nullptr) - return _parser.emitError( - _parser.getCurrentLocation(), "Expecting an operand."); - return mlir::success(); -} - -} // namespace onnf diff --git a/src/compiler/dialect/krnl/parser_helper.hpp b/src/compiler/dialect/krnl/parser_helper.hpp deleted file mode 100644 index e1928fd..0000000 --- a/src/compiler/dialect/krnl/parser_helper.hpp +++ /dev/null @@ -1,46 +0,0 @@ -//===------------------ parser_helper.hpp - MLIR Operations ---------------===// -// -// Copyright 2019 The IBM Research Authors. -// -// ============================================================================= -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include - -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/StandardTypes.h" - -namespace onnf { - -class KrnlDialectOperandParser { - public: - KrnlDialectOperandParser(mlir::OpAsmParser& parser) - : _parser(parser), _builder(parser.getBuilder()){}; - - // Parse an optional operand. - mlir::ParseResult ParseOptionalOperand( - mlir::Type operand_type, mlir::Value*& operand); - - // Parse a required operand. 
- mlir::ParseResult ParseOperand( - mlir::Type operand_type, mlir::Value*& operand); - - // Do we have more operands to parse? - bool has_operand_left() { return !_operand_ref_queue.empty(); } - - private: - mlir::OpAsmParser& _parser; - - mlir::Builder& _builder; - - // A queue storing the parsed SSA id references. - std::queue _operand_ref_queue; -}; - -} // namespace onnf diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp index 7401037..430818b 100644 --- a/src/compiler/pass/lower_frontend_to_krnl.cpp +++ b/src/compiler/pass/lower_frontend_to_krnl.cpp @@ -16,6 +16,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "src/compiler/dialect/krnl/krnl_helper.hpp" #include "src/compiler/dialect/krnl/krnl_ops.hpp" #include "src/compiler/dialect/onnx/onnx_ops.hpp" @@ -43,13 +44,12 @@ static MemRefType convertTensorToMemRef(TensorType type) { } /// Insert an allocation and deallocation for the given MemRefType. -static Value* insertAllocAndDealloc( - MemRefType type, Location loc, PatternRewriter& rewriter, - Value *oldMemRef = nullptr) { +static Value* insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter& rewriter, Value* oldMemRef = nullptr) { // Put together alloc operands for any dynamic dimensions of the memref. AllocOp alloc; if (oldMemRef) { - SmallVector allocOperands; + SmallVector allocOperands; auto memRefShape = type.getShape(); for (int i = 0; i < memRefShape.size(); ++i) if (memRefShape[i] < 0) @@ -95,7 +95,7 @@ struct ONNXAddOpLowering : public ConversionPattern { // dimensions with the result at this pre-optimization phase. // TODO: verify that dimensions match. // TODO: can the dimension of the result differ after optimizations? - Value *alloc; + Value* alloc; if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter); else @@ -122,33 +122,22 @@ struct ONNXAddOpLowering : public ConversionPattern { } Block& optimizationBlock = optimizedLoopsOp.region().front(); + KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); // Iterate over the loop nest. // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape // to KrnlIterateOp instead. - SmallVector operandBounds; - SmallVector constBounds; - SmallVector boundTypes; for (int i = 0; i < rank; ++i) { if (memRefShape[i] < 0) { - // This is a dynamic value, hence use operands. 
-        // Lower bound
-        constBounds.push_back(0);
-        boundTypes.push_back(0);
-        // Upper bound
-        operandBounds.push_back(
+        pack.pushConstantBound(0);
+        pack.pushOperandBound(
             rewriter.create<DimOp>(loc, operands[0], i).getResult());
-        boundTypes.push_back(1);
       } else {
-        // Lower bound
-        constBounds.push_back(0);
-        boundTypes.push_back(0);
-        // Upper bound
-        constBounds.push_back(memRefShape[i]);
-        boundTypes.push_back(0);
+        pack.pushConstantBound(0);
+        pack.pushConstantBound(memRefShape[i]);
       }
     }
-    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, originalLoops,
-        optimizedLoops, operandBounds, constBounds, boundTypes);
+
+    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
     Block& iterationBlock = iterateOp.bodyRegion().front();
     // Now perform the insertions into the body of the
@@ -169,14 +158,12 @@ struct ONNXAddOpLowering : public ConversionPattern {
     SmallVector<Value*, 4> loopIVs;
     for (auto arg : iterationBlock.getArguments())
       loopIVs.push_back(arg);
-    auto loadedFirstVal =
-        rewriter.create<LoadOp>(loc, operands[0], loopIVs);
-    auto loadedSecondVal =
-        rewriter.create<LoadOp>(loc, operands[1], loopIVs);
+    auto loadedFirstVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
+    auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
     // TODO: Choose type of the Add for now use the Float Add.
-    auto addOpResult = rewriter.create<AddFOp>(
-        loc, loadedFirstVal, loadedSecondVal);
+    auto addOpResult =
+        rewriter.create<AddFOp>(loc, loadedFirstVal, loadedSecondVal);
     // Store result in the resulting array.
     rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
@@ -209,8 +196,8 @@ struct TensorTypeConverter : public TypeConverter {
   /// inputs. Once unranked results can be handled gracefully this
   /// override needs to be removed in favour of the original MLIR one.]
   bool isSignatureLegal(FunctionType funcType) {
-    return llvm::all_of(funcType.getInputs(),
-        [this](Type type) { return isLegal(type); });
+    return llvm::all_of(
+        funcType.getInputs(), [this](Type type) { return isLegal(type); });
   }
 };
@@ -272,8 +259,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
   // operations were not converted successfully.
-  if (failed(applyPartialConversion(
-          module, target, patterns)))
+  if (failed(applyPartialConversion(module, target, patterns)))
     signalPassFailure();
 }
diff --git a/src/compiler/pass/passes.hpp b/src/compiler/pass/passes.hpp
index a268021..89d08a1 100644
--- a/src/compiler/pass/passes.hpp
+++ b/src/compiler/pass/passes.hpp
@@ -17,9 +17,10 @@ class Pass;
 std::unique_ptr<Pass> createShapeInferencePass();
-/// Pass for lowering frontend dialects to Krnl IR dialect.
-std::unique_ptr<Pass> createLowerToKrnlPass();
+/// Pass for lowering frontend dialects to the Krnl IR dialect.
+std::unique_ptr<Pass> createLowerToKrnlPass();
-// TODO: Add pass for lowering to LLVM IR.
+/// Pass for lowering the Krnl dialect to the affine dialect.
+std::unique_ptr createLowerKrnlPass(); } // end namespace mlir diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp index fc3ed1f..27463ab 100644 --- a/src/compiler/pass/shape_inference_pass.cpp +++ b/src/compiler/pass/shape_inference_pass.cpp @@ -75,7 +75,7 @@ class ShapeInferencePass : public mlir::FunctionPass { if (auto terminator_op = f.getBody().back().getTerminator()) { auto results = terminator_op->getOperandTypes(); f.setType(FunctionType::get(f.getType().getInputs(), - std::vector(results.begin(), results.end()), f.getContext())); + std::vector(results.begin(), results.end()), f.getContext())); } } diff --git a/src/compiler/tool/onnf_opt/CMakeLists.txt b/src/compiler/tool/onnf_opt/CMakeLists.txt index 669e999..7cde1f2 100644 --- a/src/compiler/tool/onnf_opt/CMakeLists.txt +++ b/src/compiler/tool/onnf_opt/CMakeLists.txt @@ -6,8 +6,7 @@ target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT}) target_link_libraries(onnf-opt compiler ${MLIRLibs}) whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs}) -whole_archive_link_onnf(onnf-opt onnf_lower_frontend) -whole_archive_link_onnf(onnf-opt onnf_shape_inference) +whole_archive_link_onnf(onnf-opt onnf_transform onnf_lower_frontend onnf_shape_inference) # TODO: need to investigate how to whole-archive link compiler pass to onnf-opt. target_link_libraries(onnf-opt compiler) diff --git a/src/compiler/transform/CMakeLists.txt b/src/compiler/transform/CMakeLists.txt new file mode 100644 index 0000000..65f8130 --- /dev/null +++ b/src/compiler/transform/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(onnf_transform lower_krnl.cpp) + +target_include_directories(onnf_transform + PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} + ${ONNF_SRC_ROOT}) +target_link_libraries(onnf_transform ${MLIRLibs}) +add_dependencies(onnf_transform gen_krnl_ops) diff --git a/src/compiler/transform/lower_krnl.cpp b/src/compiler/transform/lower_krnl.cpp new file mode 100644 index 0000000..921124b --- /dev/null +++ b/src/compiler/transform/lower_krnl.cpp @@ -0,0 +1,161 @@ +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/compiler/dialect/krnl/krnl_ops.hpp" +#include "src/compiler/pass/passes.hpp" + +using namespace mlir; + +namespace { + +//===----------------------------------------------------------------------===// +// Krnl to Affine Rewrite Patterns: KrnlIterate operation. +//===----------------------------------------------------------------------===// + +struct KrnlIterateOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite( + KrnlIterateOp iterateOp, PatternRewriter& rewriter) const override { + auto boundMapAttrs = + iterateOp.getAttrOfType(KrnlIterateOp::getBoundsAttrName()) + .getValue(); + auto operandItr = + iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops(); + SmallVector nestedForOps; + for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) { + // Consume input loop operand, currently do not do anything with it. + operandItr++; + + // Organize operands into lower/upper bounds in affine.for ready formats. + SmallVector lbOperands, ubOperands; + AffineMap lbMap, ubMap; + for (int boundType = 0; boundType < 2; boundType++) { + auto& operands = boundType == 0 ? lbOperands : ubOperands; + auto& map = boundType == 0 ? 
lbMap : ubMap; + map = boundMapAttrs[boundIdx + boundType] + .cast() + .getValue(); + operands.insert( + operands.end(), operandItr, operandItr + map.getNumInputs()); + std::advance(operandItr, map.getNumInputs()); + } + + nestedForOps.emplace_back(rewriter.create( + iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap)); + rewriter.setInsertionPoint(nestedForOps.back().getBody(), + nestedForOps.back().getBody()->begin()); + } + + // Replace induction variable references from those introduced by a + // single krnl.iterate to those introduced by multiple affine.for + // operations. + for (size_t i = 0; i < nestedForOps.size() - 1; i++) { + auto iterateIV = iterateOp.bodyRegion().front().getArgument(0); + auto forIV = nestedForOps[i].getBody()->getArgument(0); + iterateIV->replaceAllUsesWith(forIV); + iterateOp.bodyRegion().front().eraseArgument(0); + } + + // Pop krnl.iterate body region block arguments, leave the last one + // for convenience (it'll be taken care of by region inlining). + while (iterateOp.bodyRegion().front().getNumArguments() > 1) + iterateOp.bodyRegion().front().eraseArgument(0); + + // Transfer krnl.iterate region to innermost for op. + auto innermostForOp = nestedForOps.back(); + innermostForOp.region().getBlocks().clear(); + rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(), + innermostForOp.region().end()); + + rewriter.eraseOp(iterateOp); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// Krnl to Affine Rewrite Patterns: KrnlTerminator operation. +//===----------------------------------------------------------------------===// + +class KrnlTerminatorLowering : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite( + KrnlTerminatorOp op, PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// Krnl to Affine Rewrite Patterns: KrnlDefineLoops operation. +//===----------------------------------------------------------------------===// + +class KrnlDefineLoopsLowering : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite( + KrnlDefineLoopsOp op, PatternRewriter& rewriter) const override { + rewriter.eraseOp(op); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// Krnl to Affine Rewrite Patterns: KrnlOptimizeLoops operation. +//===----------------------------------------------------------------------===// + +class KrnlOptimizeLoopsLowering : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite( + KrnlOptimizeLoopsOp op, PatternRewriter& rewriter) const override { + rewriter.eraseOp(op); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// KrnlToAffineLoweringPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to affine loops of the krnl dialect operations. +/// At this stage the dialect will contain standard operations as well like +/// add and multiply, this pass will leave these operations intact. +namespace { +struct KrnlToAffineLoweringPass + : public FunctionPass { + void runOnFunction() final; +}; +} // end anonymous namespace. 
+ +void KrnlToAffineLoweringPass::runOnFunction() { + auto function = getFunction(); + + ConversionTarget target(getContext()); + + target.addLegalDialect(); + // We expect IR to be free of Krnl Dialect Ops. + target.addIllegalDialect(); + + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + if (failed(applyPartialConversion(getFunction(), target, patterns))) + signalPassFailure(); +} + +} // namespace + +std::unique_ptr mlir::createLowerKrnlPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "lower-krnl", "Lower Krnl dialect."); \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 61fd846..7306aa1 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -28,6 +28,7 @@ #include +#include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/Support/FileUtilities.h" #include "llvm/Support/Regex.h" #include "llvm/Support/SourceMgr.h" @@ -38,6 +39,8 @@ #include "src/compiler/pass/passes.hpp" #include "mlir/Analysis/Verifier.h" +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/MLIRContext.h" @@ -125,7 +128,20 @@ int main(int ac, char* av[]) { pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createLowerToKrnlPass()); + pm.addPass(mlir::createLowerKrnlPass()); + pm.addPass(mlir::createLowerAffinePass()); + pm.addPass(mlir::createLowerToCFGPass()); + pm.addPass(mlir::createLowerToLLVMPass()); + pm.addPass(mlir::createCanonicalizerPass()); pm.run(*module); + // Write LLVM bitcode to disk. + std::error_code EC; + llvm::raw_fd_ostream moduleBitcodeStream( + "model.bc", EC, llvm::sys::fs::F_None); + llvm::WriteBitcodeToFile( + *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); + moduleBitcodeStream.flush(); + return 0; } diff --git a/test/mlir/krnl/ops.mlir b/test/mlir/krnl/ops.mlir new file mode 100644 index 0000000..0300a11 --- /dev/null +++ b/test/mlir/krnl/ops.mlir @@ -0,0 +1,75 @@ +// RUN: onnf-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s +// RUN: onnf-opt %s | FileCheck %s + +// GENERIC-DAG: #{{.*}} = () -> (0) +// GENERIC-DAG: #{{.*}} = () -> (10) +// GENERIC-DAG: #{{.*}} = () -> (1) +// GENERIC-DAG: #{{.*}} = () -> (11) +// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 - d1) +// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 + d1) + +func @simple_iterate(%N : index) { + %ii, %ij, %ik = krnl.define_loops 3 + %oi, %oj, %ok = krnl.optimize_loops { + krnl.return_loops %ii, %ij, %ik + } : () -> (!krnl.loop, !krnl.loop, !krnl.loop) + + // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { + // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): + // GENERIC-NEXT: "krnl.terminate"() : () -> () + // GENERIC-NEXT: bounds = [#{{.*}}, #{{.*}}, #{{.*}}, #{{.*}}] + + // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 1 to 11) { + krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 1 to 11) { + + } + + // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { + // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): + // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) { + krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 0 to 10) { + // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}) ( { + // GENERIC-NEXT: ^bb0(%{{.*}}: index): + // CHECK: 
krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10) { + krnl.iterate(%ok) with (%ik -> %k = 0 to 10) { + + } + } + + // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { + // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): + // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to %{{.*}}, %{{.*}} -> %{{.*}} = 0 to 10) { + krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to %N, %ij -> %j = 0 to 10) { + + } + + return +} + +func @affine_map_bound(%N : index) { + %ii, %ij, %ik = krnl.define_loops 3 + %oi, %oj, %ok = krnl.optimize_loops { + krnl.return_loops %ii, %ij, %ik + } : () -> (!krnl.loop, !krnl.loop, !krnl.loop) + + // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { + // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): + // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) { + krnl.iterate(%oi, %oj) with (%ii -> %i = ()->(0)() to ()->(10)(), %ij -> %j = 0 to 10) { + // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { + // GENERIC-NEXT: ^bb0(%{{.*}}: index): + // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = #{{.*}}(%{{.*}}, %{{.*}}) to #{{.*}}(%{{.*}}, %{{.*}})) { + krnl.iterate(%ok) with (%ik -> %k = (d0, d1)->(d0 - d1)(%i, %j) to (d0, d1)->(d0 + d1)(%i, %j)) { + + } + + // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { + // GENERIC-NEXT: ^bb0(%{{.*}}: index): + // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = max #map{{.*}}(%{{.*}}, %{{.*}}) to min #map{{.*}}(%{{.*}}, %{{.*}})[%{{.*}}]) { + krnl.iterate(%ok) with (%ik -> %k = max (d0, d1)->(d0 - d1, 0)(%i, %j) to min (d0, d1)[s0]->(d0 + d1, s0)(%i, %j)[%N]) { + + } + } + + return +} \ No newline at end of file
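
Note on the new builder API (editorial sketch, not part of the patch): KrnlIterateOp::build now takes a single KrnlIterateOperandPack, which records the optimized and input loops and encodes every bound as an AffineMapAttr (constant bounds become constant maps, SSA bounds become symbol-identity maps). A minimal usage sketch in the style of the ONNXAddOpLowering changes above; `rewriter`, `loc`, `operands`, `memRefShape`, `rank`, `originalLoops`, and `optimizedLoops` are assumed to be in scope as in that pattern:

    // Build a lower/upper bound pair for each dimension of the result memref.
    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
    for (int i = 0; i < rank; ++i) {
      pack.pushConstantBound(0);  // lower bound is the constant 0
      if (memRefShape[i] < 0)
        // Dynamic dimension: the upper bound is the runtime extent of the input.
        pack.pushOperandBound(
            rewriter.create<DimOp>(loc, operands[0], i).getResult());
      else
        // Static dimension: the upper bound is the compile-time extent.
        pack.pushConstantBound(memRefShape[i]);
    }
    // A single pack replaces the old operand/constant/bound-type lists.
    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);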