[MLIR] Refactor Krnl Dialect and Krnl Dialect Lowering (#375)
* Store bounds as affine map attributes & check in test cases with generic printer * Upgrading MLIR MLIR is outdated on Buildbot, rebuilding a newer version. * work with new version of mlir * check-in parser tests * custom printer * nit * bug fix * enable custom asm printer test * enable custom asm printer test * more consistent variable naming * test max/min * variable naming scheme change to MLIR style * can lower krnl to llvm * kernel -> llvm * comments * bug fix * try fixing ci * fix ci * deactivate model test * fix lit test * nit * fix z buildbot
This commit is contained in:
		
							parent
							
								
									652ce4b7d4
								
							
						
					
					
						commit
						b2a1103915
					
				|  | @ -18,6 +18,8 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) | |||
| #TODO(eventually enable the following) | ||||
| #set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) | ||||
| 
 | ||||
| include(MLIR.cmake) | ||||
| 
 | ||||
| add_subdirectory(third_party/onnx) | ||||
| add_subdirectory(third_party/benchmark) | ||||
| add_subdirectory(third_party/pybind11) | ||||
|  | @ -41,7 +43,6 @@ if(Boost_FOUND) | |||
|   include_directories(${Boost_INCLUDE_DIRS}) | ||||
| endif() | ||||
| 
 | ||||
| include(MLIR.cmake) | ||||
| add_subdirectory(src/builder) | ||||
| add_subdirectory(src/compiler) | ||||
| add_subdirectory(src) | ||||
|  |  | |||
							
								
								
									
										92
									
								
								MLIR.cmake
								
								
								
								
							
							
						
						
									
										92
									
								
								MLIR.cmake
								
								
								
								
							|  | @ -5,7 +5,7 @@ if(DEFINED ENV{LLVM_SRC}) | |||
|     message(STATUS "LLVM_SRC " ${LLVM_SRC}) | ||||
|   else() | ||||
|     message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: " | ||||
|                         ${LLVM_SRC}) | ||||
|             ${LLVM_SRC}) | ||||
|   endif() | ||||
| else() | ||||
|   message(FATAL_ERROR "env variable LLVM_SRC not set") | ||||
|  | @ -18,7 +18,7 @@ if(DEFINED ENV{LLVM_BUILD}) | |||
|     message(STATUS "LLVM_BUILD " ${LLVM_BUILD}) | ||||
|   else() | ||||
|     message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: " | ||||
|                         ${LLVM_BUILD}) | ||||
|             ${LLVM_BUILD}) | ||||
|   endif() | ||||
| else() | ||||
|   message(FATAL_ERROR "env variable LLVM_BUILD not set") | ||||
|  | @ -39,9 +39,9 @@ set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir) | |||
| set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir) | ||||
| 
 | ||||
| set( | ||||
|   MLIR_INCLUDE_PATHS | ||||
|   ${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH} | ||||
|   ) | ||||
|         MLIR_INCLUDE_PATHS | ||||
|         ${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH} | ||||
| ) | ||||
| include_directories(${MLIR_INCLUDE_PATHS}) | ||||
| 
 | ||||
| # Threading libraries required due to parallel pass execution. | ||||
|  | @ -49,9 +49,9 @@ find_package(Threads REQUIRED) | |||
| 
 | ||||
| function(find_mlir_lib lib) | ||||
|   find_library(${lib} | ||||
|                NAMES ${lib} | ||||
|                PATHS ${LLVM_PROJECT_LIB} | ||||
|                NO_DEFAULT_PATH) | ||||
|           NAMES ${lib} | ||||
|           PATHS ${LLVM_PROJECT_LIB} | ||||
|           NO_DEFAULT_PATH) | ||||
| endfunction(find_mlir_lib) | ||||
| 
 | ||||
| find_mlir_lib(MLIRAffineOps) | ||||
|  | @ -70,6 +70,10 @@ find_mlir_lib(MLIRTransforms) | |||
| find_mlir_lib(MLIRTransformUtils) | ||||
| find_mlir_lib(MLIRSupport) | ||||
| find_mlir_lib(MLIROptMain) | ||||
| find_mlir_lib(MLIRTargetLLVMIRModuleTranslation) | ||||
| find_mlir_lib(MLIRTargetLLVMIR) | ||||
| find_mlir_lib(MLIRTransformUtils) | ||||
| find_mlir_lib(MLIRTranslation) | ||||
| find_mlir_lib(MLIRVectorOps) | ||||
| 
 | ||||
| find_mlir_lib(LLVMCore) | ||||
|  | @ -80,46 +84,52 @@ find_mlir_lib(LLVMRemarks) | |||
| find_mlir_lib(LLVMIRReader) | ||||
| find_mlir_lib(LLVMTransformUtils) | ||||
| find_mlir_lib(LLVMBitstreamReader) | ||||
| find_mlir_lib(LLVMAnalysis) | ||||
| find_mlir_lib(LLVMBitWriter) | ||||
| find_mlir_lib(LLVMBitReader) | ||||
| find_mlir_lib(LLVMMC) | ||||
| find_mlir_lib(LLVMMCParser) | ||||
| find_mlir_lib(LLVMObject) | ||||
| find_mlir_lib(LLVMProfileData) | ||||
| find_mlir_lib(LLVMDemangle) | ||||
| 
 | ||||
| 
 | ||||
| set(MLIRLibsOnce | ||||
|         LLVMAnalysis | ||||
|         LLVMAsmParser | ||||
|         LLVMBinaryFormat | ||||
|         LLVMBitReader | ||||
|         LLVMBitstreamReader | ||||
|         LLVMBitWriter | ||||
|         LLVMCore | ||||
|         LLVMIRReader | ||||
|         LLVMMC | ||||
|         LLVMMCParser | ||||
|         LLVMObject | ||||
|         LLVMRemarks | ||||
|         LLVMSupport | ||||
|         LLVMTransformUtils | ||||
|         LLVMProfileData | ||||
|         LLVMDemangle | ||||
|         MLIRAffineOps | ||||
|         MLIRAffineToStandard | ||||
|         MLIRAnalysis | ||||
|         MLIRExecutionEngine | ||||
|         MLIRIR | ||||
|         MLIRLLVMIR | ||||
|         MLIRLoopOps | ||||
|         MLIRLoopToStandard | ||||
|         MLIROptMain | ||||
|         MLIRParser | ||||
|         MLIRPass | ||||
|         MLIRStandardOps | ||||
|         MLIRStandardToLLVM | ||||
|         MLIRSupport | ||||
|         MLIRTargetLLVMIR | ||||
|         MLIRTransforms | ||||
|         MLIRAffineOps | ||||
|         MLIRAffineToStandard | ||||
|         MLIRAnalysis | ||||
|         MLIRExecutionEngine | ||||
|         MLIRIR | ||||
|         MLIRLLVMIR | ||||
|         MLIRLoopToStandard | ||||
|         MLIRParser | ||||
|         MLIRPass | ||||
|         MLIRStandardOps | ||||
|         MLIRStandardToLLVM | ||||
|         MLIRTargetLLVMIR | ||||
|         MLIRTargetLLVMIRModuleTranslation | ||||
|         MLIRTransforms | ||||
|         MLIRTransformUtils | ||||
|         MLIRLoopOps | ||||
|         MLIRSupport | ||||
|         MLIROptMain | ||||
|         LLVMCore | ||||
|         LLVMSupport | ||||
|         LLVMAsmParser | ||||
|         LLVMIRReader | ||||
|         LLVMTransformUtils | ||||
|         LLVMBinaryFormat | ||||
|         LLVMRemarks | ||||
|         LLVMBitstreamReader) | ||||
|         MLIRTranslation) | ||||
| 
 | ||||
| set(MLIRLibs | ||||
|         ${MLIRLibsOnce} | ||||
|  | @ -142,7 +152,7 @@ function(whole_archive_link target lib_dir) | |||
|     set(link_flags "${link_flags} -L${lib_dir} ") | ||||
|     foreach(LIB ${ARGN}) | ||||
|       string(CONCAT link_flags ${link_flags} | ||||
|                     "-Wl,-force_load ${lib_dir}/lib${LIB}.a ") | ||||
|               "-Wl,-force_load ${lib_dir}/lib${LIB}.a ") | ||||
|     endforeach(LIB) | ||||
|   elseif(MSVC) | ||||
|     foreach(LIB ${ARGN}) | ||||
|  | @ -170,20 +180,20 @@ function(whole_archive_link_onnf target) | |||
| endfunction(whole_archive_link_onnf) | ||||
| 
 | ||||
| set(LLVM_CMAKE_DIR | ||||
|     "${LLVM_BUILD}/lib/cmake/llvm" | ||||
|     CACHE PATH "Path to LLVM cmake modules") | ||||
|         "${LLVM_BUILD}/lib/cmake/llvm" | ||||
|         CACHE PATH "Path to LLVM cmake modules") | ||||
| list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") | ||||
| include(AddLLVM) | ||||
| include(TableGen) | ||||
| 
 | ||||
| function(onnf_tablegen ofn) | ||||
|   tablegen(MLIR | ||||
|            ${ARGV} | ||||
|            "-I${MLIR_SRC_INCLUDE_PATH}" | ||||
|            "-I${MLIR_BIN_INCLUDE_PATH}") | ||||
|           ${ARGV} | ||||
|           "-I${MLIR_SRC_INCLUDE_PATH}" | ||||
|           "-I${MLIR_BIN_INCLUDE_PATH}") | ||||
|   set(TABLEGEN_OUTPUT | ||||
|       ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn} | ||||
|       PARENT_SCOPE) | ||||
|           ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn} | ||||
|           PARENT_SCOPE) | ||||
| endfunction() | ||||
| 
 | ||||
| # Import the pre-built mlir TableGen as an imported exetuable. It is required by | ||||
|  | @ -191,5 +201,5 @@ endfunction() | |||
| # table gen utility itself can be detected and cause re-compilation of .td file. | ||||
| add_executable(mlir-tblgen IMPORTED) | ||||
| set_property(TARGET mlir-tblgen | ||||
|              PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen) | ||||
|         PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen) | ||||
| set(MLIR_TABLEGEN_EXE mlir-tblgen) | ||||
|  |  | |||
|  | @ -2,8 +2,9 @@ add_executable(onnf main.cpp) | |||
| 
 | ||||
| target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES}) | ||||
| whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs}) | ||||
| whole_archive_link_onnf(onnf onnf_transform) | ||||
| 
 | ||||
| target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR}) | ||||
| target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR}) | ||||
| 
 | ||||
| install(TARGETS onnf DESTINATION bin) | ||||
| install(TARGETS onnf DESTINATION bin) | ||||
|  |  | |||
|  | @ -6,8 +6,8 @@ add_library( | |||
|         dialect/krnl/krnl_types.hpp | ||||
|         dialect/onnx/onnx_ops.cpp | ||||
|         dialect/onnx/onnx_ops.hpp | ||||
|         dialect/krnl/parser_helper.cpp | ||||
|         dialect/krnl/parser_helper.hpp | ||||
|         dialect/krnl/krnl_helper.cpp | ||||
|         dialect/krnl/krnl_helper.hpp | ||||
|         pass/shape_inference_pass.cpp | ||||
|         pass/shape_inference_interface.hpp | ||||
|         dialect/onnx/onnxop.inc | ||||
|  | @ -82,4 +82,5 @@ target_include_directories(onnf_lower_frontend | |||
| target_link_libraries(onnf_lower_frontend ${MLIRLibs}) | ||||
| add_dependencies(onnf_lower_frontend gen_krnl_ops) | ||||
| 
 | ||||
| add_subdirectory(transform) | ||||
| add_subdirectory(tool) | ||||
|  |  | |||
|  | @ -0,0 +1,139 @@ | |||
| #include "mlir/Dialect/AffineOps/AffineOps.h" | ||||
| #include "mlir/IR/AffineExpr.h" | ||||
| 
 | ||||
| #include "src/compiler/dialect/krnl/krnl_ops.hpp" | ||||
| 
 | ||||
| #include "krnl_helper.hpp" | ||||
| 
 | ||||
| namespace onnf { | ||||
| 
 | ||||
| using namespace mlir; | ||||
| 
 | ||||
| ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ||||
|     const Type& operandType, Value*& operand) { | ||||
|   // If operand queue is empty, parse more operands and cache them.
 | ||||
|   if (_operandRefQueue.empty()) { | ||||
|     // Parse operand types:
 | ||||
|     llvm::SmallVector<OpAsmParser::OperandType, 2> operand_refs; | ||||
|     _parser.parseOperandList(operand_refs); | ||||
| 
 | ||||
|     // Record operands:
 | ||||
|     for (auto& operand_ref : operand_refs) | ||||
|       _operandRefQueue.emplace(operand_ref); | ||||
|   } | ||||
| 
 | ||||
|   // If we parsed some operand reference(s), resolve the ref to an operand:
 | ||||
|   if (!_operandRefQueue.empty()) { | ||||
|     auto operand_ref = _operandRefQueue.front(); | ||||
|     _operandRefQueue.pop(); | ||||
| 
 | ||||
|     llvm::SmallVector<Value*, 1> operands; | ||||
|     _parser.resolveOperand(operand_ref, operandType, operands); | ||||
|     operand = operands.front(); | ||||
|     return success(); | ||||
|   } else { | ||||
|     operand = nullptr; | ||||
|     return failure(); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ||||
|     const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) { | ||||
|   Value* operand = nullptr; | ||||
|   if (ParseOptionalOperand(operandType, operand)) | ||||
|     return failure(); | ||||
| 
 | ||||
|   operandList.emplace_back(operand); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| ParseResult KrnlDialectOperandParser::ParseOperand( | ||||
|     const Type& operandType, Value*& operand) { | ||||
|   if (ParseOptionalOperand(operandType, operand)) | ||||
|     return _parser.emitError( | ||||
|         _parser.getCurrentLocation(), "Expecting an operand."); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| ParseResult KrnlDialectOperandParser::ParseOperand( | ||||
|     const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) { | ||||
|   if (ParseOptionalOperand(operandType, operandList)) | ||||
|     return _parser.emitError( | ||||
|         _parser.getCurrentLocation(), "Expecting an operand."); | ||||
| 
 | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims, | ||||
|     unsigned numSymbols, OpAsmPrinter& p) { | ||||
|   p << '('; | ||||
|   p.printOperands(begin, begin + numDims); | ||||
|   p << ')'; | ||||
| 
 | ||||
|   if (numSymbols) { | ||||
|     p << '['; | ||||
|     p.printOperands(begin + numDims, begin + numDims + numSymbols); | ||||
|     p << ']'; | ||||
|   } | ||||
| 
 | ||||
|   begin = std::next(begin, numDims + numSymbols); | ||||
| } | ||||
| 
 | ||||
| void printBound(AffineMapAttr boundMap, | ||||
|     Operation::operand_iterator& boundOperandsBeg, const char* prefix, | ||||
|     OpAsmPrinter& p) { | ||||
|   AffineMap map = boundMap.getValue(); | ||||
| 
 | ||||
|   // Check if this bound should be printed using custom assembly form.
 | ||||
|   // The decision to restrict printing custom assembly form to trivial cases
 | ||||
|   // comes from the will to roundtrip MLIR binary -> text -> binary in a
 | ||||
|   // lossless way.
 | ||||
|   // Therefore, custom assembly form parsing and printing is only supported for
 | ||||
|   // zero-operand constant maps and single symbol operand identity maps.
 | ||||
|   if (map.getNumResults() == 1) { | ||||
|     AffineExpr expr = map.getResult(0); | ||||
| 
 | ||||
|     // Print constant bound.
 | ||||
|     if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { | ||||
|       if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { | ||||
|         p << constExpr.getValue(); | ||||
|         return; | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     // Print bound that consists of a single SSA symbol if the map is over a
 | ||||
|     // single symbol.
 | ||||
|     if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { | ||||
|       if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) { | ||||
|         p.printOperand(*(boundOperandsBeg++)); | ||||
|         return; | ||||
|       } | ||||
|     } | ||||
|   } else { | ||||
|     // Map has multiple results. Print 'min' or 'max' prefix.
 | ||||
|     p << prefix << ' '; | ||||
|   } | ||||
| 
 | ||||
|   // Print the map and its operands.
 | ||||
|   p << boundMap; | ||||
|   printDimAndSymbolList( | ||||
|       boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p); | ||||
| } | ||||
| }  // namespace onnf
 | ||||
| 
 | ||||
| namespace mlir { | ||||
| void KrnlIterateOperandPack::pushConstantBound(int64_t bound) { | ||||
|   if (boundMaps.size() % 2 == 0) | ||||
|     _operands.emplace_back(inputLoops[boundMaps.size() / 2]); | ||||
|   AffineMap map = builder.getConstantAffineMap(bound); | ||||
|   boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||
| } | ||||
| 
 | ||||
| void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) { | ||||
|   if (boundMaps.size() % 2 == 0) | ||||
|     _operands.emplace_back(inputLoops[boundMaps.size() / 2]); | ||||
|   AffineMap map = builder.getSymbolIdentityMap(); | ||||
|   boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||
|   _operands.emplace_back(operand); | ||||
| } | ||||
| }  // namespace mlir
 | ||||
|  | @ -0,0 +1,102 @@ | |||
| #pragma once | ||||
| 
 | ||||
| #include <queue> | ||||
| 
 | ||||
| #include "mlir/IR/Attributes.h" | ||||
| #include "mlir/IR/Builders.h" | ||||
| #include "mlir/IR/Dialect.h" | ||||
| #include "mlir/IR/OpDefinition.h" | ||||
| #include "mlir/IR/OpImplementation.h" | ||||
| #include "mlir/IR/StandardTypes.h" | ||||
| 
 | ||||
| namespace onnf { | ||||
| 
 | ||||
| class KrnlDialectOperandParser { | ||||
|  public: | ||||
|   explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser) | ||||
|       : _parser(parser), _builder(parser.getBuilder()){}; | ||||
| 
 | ||||
|   // Parse an optional operand.
 | ||||
|   mlir::ParseResult ParseOptionalOperand( | ||||
|       const mlir::Type& operandType, mlir::Value*& operand); | ||||
| 
 | ||||
|   // Parse an optional operand and push it to an operand list.
 | ||||
|   mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType, | ||||
|       llvm::SmallVectorImpl<mlir::Value*>& operandList); | ||||
| 
 | ||||
|   // Parse a required operand.
 | ||||
|   mlir::ParseResult ParseOperand( | ||||
|       const mlir::Type& operandType, mlir::Value*& operand); | ||||
| 
 | ||||
|   // Parse a required operand and push it to an operand list.
 | ||||
|   mlir::ParseResult ParseOperand(const mlir::Type& operandType, | ||||
|       llvm::SmallVectorImpl<mlir::Value*>& operandList); | ||||
| 
 | ||||
|   // Do we have more operands to parse?
 | ||||
|   bool hasOperandLeft() { return !_operandRefQueue.empty(); } | ||||
| 
 | ||||
|  private: | ||||
|   mlir::OpAsmParser& _parser; | ||||
| 
 | ||||
|   mlir::Builder& _builder; | ||||
| 
 | ||||
|   // A queue storing the parsed SSA id references.
 | ||||
|   std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue; | ||||
| }; | ||||
| 
 | ||||
| // Adapted from:
 | ||||
| // https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
 | ||||
| // Main difference is that it advances the iterator `begin` as it consumes
 | ||||
| // dimension and symbol operands.
 | ||||
| void printDimAndSymbolList(mlir::Operation::operand_iterator& begin, | ||||
|     unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p); | ||||
| 
 | ||||
| // Adapted from:
 | ||||
| // https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
 | ||||
| // Main difference is that it advances the iterator `boundOperandsBeg` as it
 | ||||
| // prints bound.
 | ||||
| void printBound(mlir::AffineMapAttr boundMap, | ||||
|     mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix, | ||||
|     mlir::OpAsmPrinter& p); | ||||
| }  // namespace onnf
 | ||||
| 
 | ||||
| namespace mlir { | ||||
| 
 | ||||
| struct KrnlIterateOperandPack { | ||||
|   KrnlIterateOperandPack(mlir::Builder& builder, | ||||
|       llvm::ArrayRef<mlir::Value*> inputLoops, | ||||
|       llvm::ArrayRef<mlir::Value*> optimizedLoops) | ||||
|       : builder(builder), | ||||
|         inputLoops(inputLoops), | ||||
|         optimizedLoops(optimizedLoops) { | ||||
|     _operands.insert( | ||||
|         _operands.end(), optimizedLoops.begin(), optimizedLoops.end()); | ||||
|   } | ||||
| 
 | ||||
|   void pushConstantBound(int64_t bound); | ||||
| 
 | ||||
|   void pushOperandBound(mlir::Value* operand); | ||||
| 
 | ||||
|   llvm::SmallVector<mlir::Value*, 8> getOperands() const { return _operands; } | ||||
| 
 | ||||
|   mlir::ArrayAttr getAttributes() const { | ||||
|     return builder.getArrayAttr(boundMaps); | ||||
|   } | ||||
| 
 | ||||
|   size_t getNumOptimizedLoops() const { return optimizedLoops.size(); } | ||||
| 
 | ||||
|   size_t getNumInputLoops() const { return inputLoops.size(); } | ||||
| 
 | ||||
|  private: | ||||
|   int _boundIdx = 0; | ||||
| 
 | ||||
|   llvm::SmallVector<mlir::Value*, 8> _operands; | ||||
| 
 | ||||
|   llvm::SmallVector<mlir::Attribute, 8> boundMaps; | ||||
| 
 | ||||
|   llvm::ArrayRef<mlir::Value*> inputLoops, optimizedLoops; | ||||
| 
 | ||||
|   mlir::Builder& builder; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace mlir
 | ||||
|  | @ -9,10 +9,10 @@ | |||
| #include <iostream> | ||||
| #include <queue> | ||||
| 
 | ||||
| #include "src/compiler/dialect/krnl/parser_helper.hpp" | ||||
| 
 | ||||
| #include "llvm/ADT/SetVector.h" | ||||
| #include "llvm/ADT/SmallBitVector.h" | ||||
| #include "mlir/Dialect/AffineOps/AffineOps.h" | ||||
| #include "mlir/Dialect/StandardOps/Ops.h" | ||||
| #include "mlir/IR/Block.h" | ||||
| #include "mlir/IR/Builders.h" | ||||
| #include "mlir/IR/Function.h" | ||||
|  | @ -24,6 +24,8 @@ | |||
| #include "mlir/IR/PatternMatch.h" | ||||
| #include "mlir/Transforms/DialectConversion.h" | ||||
| 
 | ||||
| #include "src/compiler/dialect/krnl/krnl_helper.hpp" | ||||
| 
 | ||||
| #include "krnl_ops.hpp" | ||||
| 
 | ||||
| using namespace mlir; | ||||
|  | @ -52,26 +54,25 @@ void KrnlDefineLoopsOp::build( | |||
| } | ||||
| 
 | ||||
| void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) { | ||||
|   auto num_loop_attr = op.getAttrOfType<IntegerAttr>(op.getNumLoopsAttrName()); | ||||
|   p << "krnl.define_loops " << num_loop_attr.getValue().getSExtValue(); | ||||
|   auto numLoopAttr = | ||||
|       op.getAttrOfType<IntegerAttr>(KrnlDefineLoopsOp::getNumLoopsAttrName()); | ||||
|   p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue(); | ||||
| } | ||||
| 
 | ||||
| ParseResult parseKrnlDefineLoopsOp( | ||||
|     OpAsmParser& parser, OperationState& result) { | ||||
|   // Parse the attribute indicating number of loops defined.
 | ||||
|   IntegerAttr num_loops; | ||||
|   IntegerAttr numLoops; | ||||
|   auto& builder = parser.getBuilder(); | ||||
|   auto int32_type = builder.getIntegerType(64); | ||||
|   if (parser.parseAttribute(num_loops, int32_type, | ||||
|   auto intType = builder.getIntegerType(64); | ||||
|   if (parser.parseAttribute(numLoops, intType, | ||||
|           KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes)) | ||||
|     return failure(); | ||||
| 
 | ||||
|   auto loop_types = llvm::SmallVector<Type, 4>( | ||||
|       num_loops.getValue().getSExtValue(), LoopType::get(builder.getContext())); | ||||
|   if (parser.addTypesToList(loop_types, result.types)) | ||||
|   auto loopTypes = llvm::SmallVector<Type, 4>( | ||||
|       numLoops.getValue().getSExtValue(), LoopType::get(builder.getContext())); | ||||
|   if (parser.addTypesToList(loopTypes, result.types)) | ||||
|     return failure(); | ||||
| 
 | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  | @ -142,39 +143,14 @@ ParseResult parseKrnlOptimizeLoopsOp( | |||
|  *   %i0 = 10 to N : %i1 = M to 20 | ||||
|  */ | ||||
| void KrnlIterateOp::build(Builder* builder, OperationState& result, | ||||
|     ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops, | ||||
|     ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds, | ||||
|     ArrayRef<int> bound_types) { | ||||
|     KrnlIterateOperandPack operandPack) { | ||||
|   // Record optimized loops and the number of such loops.
 | ||||
|   result.addOperands(optimized_loops); | ||||
|   result.addOperands(operandPack.getOperands()); | ||||
|   result.addAttribute( | ||||
|       KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes()); | ||||
| 
 | ||||
|   result.addAttribute(getNumOptimizedLoopsAttrName(), | ||||
|       builder->getI64IntegerAttr(optimized_loops.size())); | ||||
| 
 | ||||
|   // Record input loops and the number of such loops.
 | ||||
|   result.addOperands(input_loops); | ||||
|   result.addAttribute(getNumInputLoopsAttrName(), | ||||
|       builder->getI64IntegerAttr(input_loops.size())); | ||||
| 
 | ||||
|   // Record bound either as attribute or from operand list.
 | ||||
|   auto next_operand_bound = operand_bounds.begin(); | ||||
|   auto next_const_bound = const_bounds.begin(); | ||||
|   for (size_t i = 0; i < bound_types.size(); i++) { | ||||
|     auto bound_type = bound_types[i]; | ||||
|     if (bound_type == 0) { | ||||
|       // Constant bound.
 | ||||
|       result.addAttribute(getBoundAttrName(i / 2, i % 2), | ||||
|           builder->getI64IntegerAttr(*next_const_bound)); | ||||
|       next_const_bound = std::next(next_const_bound); | ||||
|     } else { | ||||
|       // Operand bound.
 | ||||
|       result.addOperands(*next_operand_bound); | ||||
|       next_operand_bound = std::next(next_operand_bound); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Record bound types as attribute:
 | ||||
|   result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(), | ||||
|       builder->getI32ArrayAttr(bound_types)); | ||||
|       builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops())); | ||||
| 
 | ||||
|   // Create a region and a block for the body. The arguments of the region are
 | ||||
|   // the loop induction variables; there can be multiple induction variables
 | ||||
|  | @ -182,7 +158,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result, | |||
|   Region* bodyRegion = result.addRegion(); | ||||
|   auto* body = new Block(); | ||||
|   auto body_args = llvm::SmallVector<Type, 4>( | ||||
|       input_loops.size(), IndexType::get(builder->getContext())); | ||||
|       operandPack.getNumInputLoops(), IndexType::get(builder->getContext())); | ||||
|   body->addArguments(body_args); | ||||
|   bodyRegion->push_back(body); | ||||
| 
 | ||||
|  | @ -192,57 +168,31 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result, | |||
| void print(OpAsmPrinter& p, KrnlIterateOp& op) { | ||||
|   p << "krnl.iterate("; | ||||
|   // Print optimized loops:
 | ||||
|   auto num_optimized_loops = op.getNumOptimizedLoops(); | ||||
|   p.printOperands(op.operand_begin(), op.operand_begin() + num_optimized_loops); | ||||
|   auto numOptimizedLoops = op.getNumOptimizedLoops(); | ||||
|   p.printOperands(op.operand_begin(), op.operand_begin() + numOptimizedLoops); | ||||
|   p << ") with ("; | ||||
| 
 | ||||
|   // Set up iterator to input loops:
 | ||||
|   auto num_input_loops = op.getNumInputLoops(); | ||||
|   auto input_loop_begin = op.operand_begin() + num_optimized_loops; | ||||
|   auto inductionVars = op.bodyRegion().begin()->getArguments(); | ||||
|   auto boundItr = | ||||
|       op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName()) | ||||
|           .getValue() | ||||
|           .begin(); | ||||
|   auto operandItr = op.operand_begin() + numOptimizedLoops; | ||||
| 
 | ||||
|   // Set up iterators to operand bounds.
 | ||||
|   auto next_operand_bound = input_loop_begin + num_input_loops; | ||||
| 
 | ||||
|   // Function to print a lower or upper bound.
 | ||||
|   auto print_bound = [&](ArrayRef<Attribute> bound_types, size_t idx) { | ||||
|     IntegerAttr type = bound_types[idx].dyn_cast<IntegerAttr>(); | ||||
|     if (type.getValue().getSExtValue() == 0) { | ||||
|       // Bound is an integer attribute.
 | ||||
|       auto bound_idx = idx / 2; | ||||
|       auto is_ub = idx % 2; | ||||
|       IntegerAttr bound = op.getAttrOfType<IntegerAttr>( | ||||
|           KrnlIterateOp::getBoundAttrName(bound_idx, is_ub)); | ||||
|       p << bound.getValue().getSExtValue(); | ||||
|     } else { | ||||
|       // Bound is an operand.
 | ||||
|       p.printOperand(*next_operand_bound); | ||||
|       next_operand_bound = std::next(next_operand_bound); | ||||
|     } | ||||
|   }; | ||||
| 
 | ||||
|   auto induction_variables = op.bodyRegion().front().getArguments(); | ||||
|   ArrayRef<Attribute> bound_types = | ||||
|       op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundTypesAttrName()) | ||||
|           .getValue(); | ||||
| 
 | ||||
|   // Print input loop operands, induction variables and their ranges.
 | ||||
|   for (size_t i = 0; i < num_input_loops; i++) { | ||||
|     if (i != 0) | ||||
|       p << ", "; | ||||
| 
 | ||||
|     p.printOperand(*std::next(input_loop_begin, i)); | ||||
|   std::string delimiter; | ||||
|   for (auto& var : inductionVars) { | ||||
|     p << delimiter; | ||||
|     p.printOperand(*operandItr++); | ||||
|     p << " -> "; | ||||
| 
 | ||||
|     // Print induction variable block argument.
 | ||||
|     p.printOperand(induction_variables[i]); | ||||
|     p.printOperand(var); | ||||
|     p << " = "; | ||||
| 
 | ||||
|     print_bound(bound_types, 2 * i);  // Print lower bound.
 | ||||
|     onnf::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "max", p); | ||||
|     p << " to "; | ||||
|     print_bound(bound_types, 2 * i + 1);  // Print upper bound.
 | ||||
|     onnf::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "min", p); | ||||
|     delimiter = ", "; | ||||
|   } | ||||
|   p << ")"; | ||||
| 
 | ||||
|   p << ")"; | ||||
|   p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false, | ||||
|       /*printBlockTerminators=*/false); | ||||
| } | ||||
|  | @ -250,80 +200,109 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) { | |||
| ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) { | ||||
|   auto builder = parser.getBuilder(); | ||||
|   auto context = builder.getContext(); | ||||
|   onnf::KrnlDialectOperandParser operand_parser(parser); | ||||
|   onnf::KrnlDialectOperandParser operandParser(parser); | ||||
| 
 | ||||
|   // Parse optimized loops:
 | ||||
|   SmallVector<OpAsmParser::OperandType, 4> num_optimized_loops; | ||||
|   SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs; | ||||
|   if (parser.parseOperandList( | ||||
|           num_optimized_loops, OpAsmParser::Delimiter::Paren) || | ||||
|       parser.resolveOperands(num_optimized_loops, | ||||
|           optimizedLoopRefs, OpAsmParser::Delimiter::Paren) || | ||||
|       parser.resolveOperands(optimizedLoopRefs, | ||||
|           LoopType::get(result.getContext()), result.operands)) | ||||
|     return failure(); | ||||
| 
 | ||||
|   // Record how many optimized loops did we parse.
 | ||||
|   result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(), | ||||
|       builder.getI64IntegerAttr(num_optimized_loops.size())); | ||||
|       builder.getI64IntegerAttr(optimizedLoopRefs.size())); | ||||
| 
 | ||||
|   // Parse input loops and their lower and upper bounds.
 | ||||
|   SmallVector<OpAsmParser::OperandType, 4> in_loop_refs, induction_var_refs; | ||||
|   SmallVector<Value*, 4> in_loop_operands, operand_bounds; | ||||
|   SmallVector<Attribute, 4> bound_types; | ||||
|   SmallVector<IntegerAttr, 4> const_bounds; | ||||
|   SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs; | ||||
|   SmallVector<Attribute, 4> boundMaps; | ||||
| 
 | ||||
|   if (parser.parseKeyword("with") || parser.parseLParen()) | ||||
|     return failure(); | ||||
| 
 | ||||
|   // A function to parse a lower or upper bound.
 | ||||
|   auto parse_bound = [&result, &builder, &operand_parser, &parser, &bound_types, | ||||
|                          &operand_bounds, &const_bounds]( | ||||
|                          bool is_ub, size_t bound_pair_count) -> ParseResult { | ||||
|   auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps]( | ||||
|                         bool isUpper) -> ParseResult { | ||||
|     // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
 | ||||
|     // the map has multiple results.
 | ||||
|     bool failedToParsedMinMax = | ||||
|         failed(parser.parseOptionalKeyword(isUpper ? "min" : "max")); | ||||
| 
 | ||||
|     // Try parse an SSA operand.
 | ||||
|     Value* bound; | ||||
|     operand_parser.ParseOptionalOperand(builder.getIndexType(), bound); | ||||
|     if (succeeded(operandParser.ParseOptionalOperand( | ||||
|             builder.getIndexType(), result.operands))) { | ||||
|       AffineMap map = builder.getSymbolIdentityMap(); | ||||
|       boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||
|       return success(); | ||||
|     } | ||||
| 
 | ||||
|     if (bound != nullptr) { | ||||
|       // Parsed an SSA id as bound.
 | ||||
|       operand_bounds.emplace_back(bound); | ||||
|       // Record bound_type as an operand type.
 | ||||
|       bound_types.emplace_back(builder.getI32IntegerAttr(0)); | ||||
|     } else { | ||||
|       // Bound is not an SSA id, then it must be an integer.
 | ||||
|       // Parse an integer constant attribute.
 | ||||
|       IntegerAttr boundAttr; | ||||
|       if (parser.parseAttribute(boundAttr, builder.getIndexType(), | ||||
|               KrnlIterateOp::getBoundAttrName(bound_pair_count, is_ub), | ||||
|               result.attributes)) | ||||
|     // Bound is not an SSA id, then it must be an integer.
 | ||||
|     // Parse an integer constant attribute.
 | ||||
|     // Get the attribute location.
 | ||||
|     llvm::SMLoc attrLoc = parser.getCurrentLocation(); | ||||
|     Attribute boundAttr; | ||||
|     llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer; | ||||
|     if (parser.parseAttribute( | ||||
|             boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer)) | ||||
|       return failure(); | ||||
| 
 | ||||
|     if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) { | ||||
|       unsigned currentNumOperands = result.operands.size(); | ||||
|       unsigned numDims = 0; | ||||
|       if (parseDimAndSymbolList(parser, result.operands, numDims)) | ||||
|         return failure(); | ||||
|       const_bounds.emplace_back( | ||||
|           builder.getIntegerAttr(builder.getIndexType(), boundAttr.getValue())); | ||||
| 
 | ||||
|       // Record that the bound_type is a constant integer attribute.
 | ||||
|       bound_types.emplace_back(builder.getI32IntegerAttr(1)); | ||||
|       auto map = affineMapAttr.getValue(); | ||||
|       if (map.getNumDims() != numDims) | ||||
|         return parser.emitError(parser.getNameLoc(), | ||||
|             "dim operand count and integer set dim count must match"); | ||||
| 
 | ||||
|       unsigned numDimAndSymbolOperands = | ||||
|           result.operands.size() - currentNumOperands; | ||||
|       if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) | ||||
|         return parser.emitError(parser.getNameLoc(), | ||||
|             "symbol operand count and integer set symbol count must match"); | ||||
| 
 | ||||
|       // If the map has multiple results, make sure that we parsed the min/max
 | ||||
|       // prefix.
 | ||||
|       if (map.getNumResults() > 1 && failedToParsedMinMax) { | ||||
|         if (isUpper) | ||||
|           return parser.emitError(attrLoc, | ||||
|               "upper loop bound affine map with multiple " | ||||
|               "results requires 'min' prefix"); | ||||
|         return parser.emitError(attrLoc, | ||||
|             "lower loop bound affine mapwith " | ||||
|             "multiple results requires 'max' prefix"); | ||||
|       } | ||||
|       boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||
|       return success(); | ||||
|     } | ||||
| 
 | ||||
|     if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) { | ||||
|       AffineMap map = | ||||
|           builder.getConstantAffineMap(integerAttr.getValue().getSExtValue()); | ||||
|       boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||
|     } | ||||
|   }; | ||||
| 
 | ||||
|   bool keep_parsing;            // Do we keep parsing loops/bounds?
 | ||||
|   size_t bound_pair_count = 0;  // Record the number of bound pairs parsed.
 | ||||
|   bool keepParsing;  // Do we keep parsing loops/bounds?
 | ||||
|   do { | ||||
|     // Parse an input loop operand;
 | ||||
|     Value* in_loop_operand; | ||||
|     operand_parser.ParseOperand(LoopType::get(context), in_loop_operand); | ||||
|     in_loop_operands.emplace_back(in_loop_operand); | ||||
| 
 | ||||
|     operandParser.ParseOperand(LoopType::get(context), result.operands); | ||||
|     parser.parseArrow(); | ||||
| 
 | ||||
|     // Parse induction variable.
 | ||||
|     OpAsmParser::OperandType induction_var; | ||||
|     if (parser.parseRegionArgument(induction_var) || parser.parseEqual()) | ||||
|     OpAsmParser::OperandType inductionVar; | ||||
|     if (parser.parseRegionArgument(inductionVar) || parser.parseEqual()) | ||||
|       return failure(); | ||||
|     induction_var_refs.emplace_back(induction_var); | ||||
|     inductionVarRefs.emplace_back(inductionVar); | ||||
| 
 | ||||
|     // Parse bound par (min to max).
 | ||||
|     if (parse_bound(false, bound_pair_count) || parser.parseKeyword("to") || | ||||
|         parse_bound(true, bound_pair_count)) | ||||
|     if (parseBound(/*isUpper=*/false) || parser.parseKeyword("to") || | ||||
|         parseBound(/*isUpper=*/true)) | ||||
|       return failure(); | ||||
| 
 | ||||
|     bound_pair_count++; | ||||
|     // We may fail to parse a comma if an operand bound is followed by
 | ||||
|     // a comma and the next input loop operand, in which case
 | ||||
|     // the entire "{operand bound}, {input_loop_operand}" sequence will
 | ||||
|  | @ -331,33 +310,19 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) { | |||
|     parser.parseOptionalComma(); | ||||
| 
 | ||||
|     // If we don't see a RParen token, we keep parsing.
 | ||||
|     keep_parsing = failed(parser.parseOptionalRParen()); | ||||
|   } while (keep_parsing); | ||||
|     keepParsing = failed(parser.parseOptionalRParen()); | ||||
|   } while (keepParsing); | ||||
| 
 | ||||
|   // At this point, there shouldn't be any operands left to parse.
 | ||||
|   if (operand_parser.has_operand_left()) | ||||
|   if (operandParser.hasOperandLeft()) | ||||
|     return parser.emitError(parser.getCurrentLocation()); | ||||
|   result.addAttribute( | ||||
|       KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps)); | ||||
| 
 | ||||
|   // Record how many input loops did we parse.
 | ||||
|   result.addOperands(in_loop_operands); | ||||
|   result.addAttribute(KrnlIterateOp::getNumInputLoopsAttrName(), | ||||
|       builder.getI64IntegerAttr(in_loop_operands.size())); | ||||
| 
 | ||||
|   // Add operand bounds to the list of operands of current operation.
 | ||||
|   result.addOperands(operand_bounds); | ||||
| 
 | ||||
|   // A list of 2N elements where the (2n) and (2n+1) th element specifies
 | ||||
|   // whether the lower and upper bound of the n'th induction variable is stored
 | ||||
|   // as an operand or as an attribute. N being the number of input loops
 | ||||
|   // specified in this krnl.iterate operation.
 | ||||
|   result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(), | ||||
|       builder.getArrayAttr(bound_types)); | ||||
| 
 | ||||
|   // Parse the schedule body region.
 | ||||
|   Region* region = result.addRegion(); | ||||
|   SmallVector<Type, 4> induction_var_types( | ||||
|       induction_var_refs.size(), builder.getIndexType()); | ||||
|   if (parser.parseRegion(*region, induction_var_refs, induction_var_types)) | ||||
|   SmallVector<Type, 4> inductionVarTypes( | ||||
|       inductionVarRefs.size(), builder.getIndexType()); | ||||
|   if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes)) | ||||
|     return failure(); | ||||
| 
 | ||||
|   // Ensure iterate region is closed off with krnl.terminate.
 | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ | |||
| #include "mlir/IR/OpDefinition.h" | ||||
| #include "mlir/IR/StandardTypes.h" | ||||
| 
 | ||||
| #include "src/compiler/dialect/krnl/krnl_helper.hpp" | ||||
| #include "src/compiler/dialect/krnl/krnl_types.hpp" | ||||
| 
 | ||||
| namespace mlir { | ||||
|  |  | |||
|  | @ -8,6 +8,7 @@ | |||
| 
 | ||||
| include "mlir/IR/OpBase.td" | ||||
| 
 | ||||
| 
 | ||||
| def Krnl_Dialect : Dialect { | ||||
|   let name = "krnl"; | ||||
|   let cppNamespace = ""; | ||||
|  | @ -119,17 +120,14 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> { | |||
| 
 | ||||
|   let builders = [ | ||||
|     OpBuilder<"Builder *builder, OperationState &result, " | ||||
|               "ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops, " | ||||
|               "ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds, " | ||||
|               "ArrayRef<int> bound_types"> | ||||
|               "KrnlIterateOperandPack operandPack"> | ||||
|   ]; | ||||
| 
 | ||||
|   let extraClassDeclaration = [{ | ||||
| 
 | ||||
|     // In krnl.iterate operation, three types of SSA values are stored: | ||||
|     // In krnl.iterate operation, operands are stored as such | ||||
|     // - Optimized krnl.loops. | ||||
|     // - Input krnl.loops. | ||||
|     // - SSA value based induction variable bound (parametric bound). | ||||
|     // - Input krnl.loops and their operand bounds. (TODO(Tian) explain better how we store them). | ||||
| 
 | ||||
|     // We record the number of optimized and input loops to separate these three | ||||
|     // group of operands out. | ||||
|     static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; } | ||||
|  | @ -143,32 +141,8 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> { | |||
|       return num_optimized_loops; | ||||
|     } | ||||
| 
 | ||||
|     static StringRef getNumInputLoopsAttrName() { return "num_input_loops"; } | ||||
| 
 | ||||
|     int64_t getNumInputLoops() { | ||||
|       auto num_loops = | ||||
|         getAttrOfType<IntegerAttr>( | ||||
|                 getNumInputLoopsAttrName()) | ||||
|               .getValue() | ||||
|               .getSExtValue(); | ||||
|       return num_loops; | ||||
|     } | ||||
| 
 | ||||
|     // Constant bounds are stored here as a list attribute. | ||||
|     static StringRef getConstantBoundsAttrName() { return "constant_bounds"; } | ||||
| 
 | ||||
|     // Store type of each bound as three types: | ||||
|     // - 0 = constant attribute. | ||||
|     // - 1 = operand type. | ||||
|     // - 2 = affine maps (TODO). | ||||
|     static StringRef getBoundTypesAttrName() { return "bound_types"; } | ||||
| 
 | ||||
|     // Get dynamic attribute name for the i-th lower and upper bound. | ||||
|     static std::string getBoundAttrName(int64_t i, bool is_ub) { | ||||
|       std::string bound_type = is_ub ? "_ub" : "_lb"; | ||||
|       std::string bound_idx = std::to_string(i); | ||||
|       return "__bound_" + bound_idx + bound_type; | ||||
|     } | ||||
|     // Get name of the attribute for storing bound represented using affine maps. | ||||
|     static StringRef getBoundsAttrName() { return "bounds"; } | ||||
|     }]; | ||||
| 
 | ||||
|   let printer = [{ return ::print(p, *this); }]; | ||||
|  |  | |||
|  | @ -1,52 +0,0 @@ | |||
| //===------------------ parser_helper.cpp - MLIR Operations ---------------===//
 | ||||
| //
 | ||||
| // Copyright 2019 The IBM Research Authors.
 | ||||
| //
 | ||||
| // =============================================================================
 | ||||
| //
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| #include "parser_helper.hpp" | ||||
| 
 | ||||
| #include "src/compiler/dialect/krnl/krnl_ops.hpp" | ||||
| 
 | ||||
| namespace onnf { | ||||
| 
 | ||||
| mlir::ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ||||
|     mlir::Type operand_type, mlir::Value*& operand) { | ||||
|   // If operand queue is empty, parse more operands and cache them.
 | ||||
|   if (_operand_ref_queue.empty()) { | ||||
|     // Parse operand types:
 | ||||
|     llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> operand_refs; | ||||
|     _parser.parseOperandList(operand_refs); | ||||
| 
 | ||||
|     // Record operands:
 | ||||
|     for (auto& operand_ref : operand_refs) | ||||
|       _operand_ref_queue.emplace(operand_ref); | ||||
|   } | ||||
| 
 | ||||
|   // If we parsed some operand reference(s), resolve the ref to an operand:
 | ||||
|   if (!_operand_ref_queue.empty()) { | ||||
|     auto operand_ref = _operand_ref_queue.front(); | ||||
|     _operand_ref_queue.pop(); | ||||
| 
 | ||||
|     llvm::SmallVector<mlir::Value*, 1> operands; | ||||
|     _parser.resolveOperand(operand_ref, operand_type, operands); | ||||
|     operand = operands.front(); | ||||
|     return mlir::success(); | ||||
|   } else { | ||||
|     operand = nullptr; | ||||
|     return mlir::failure(); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| mlir::ParseResult KrnlDialectOperandParser::ParseOperand( | ||||
|     mlir::Type operand_type, mlir::Value*& operand) { | ||||
|   ParseOptionalOperand(operand_type, operand); | ||||
|   if (operand == nullptr) | ||||
|     return _parser.emitError( | ||||
|         _parser.getCurrentLocation(), "Expecting an operand."); | ||||
|   return mlir::success(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace onnf
 | ||||
|  | @ -1,46 +0,0 @@ | |||
| //===------------------ parser_helper.hpp - MLIR Operations ---------------===//
 | ||||
| //
 | ||||
| // Copyright 2019 The IBM Research Authors.
 | ||||
| //
 | ||||
| // =============================================================================
 | ||||
| //
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <queue> | ||||
| 
 | ||||
| #include "mlir/IR/Builders.h" | ||||
| #include "mlir/IR/Dialect.h" | ||||
| #include "mlir/IR/OpDefinition.h" | ||||
| #include "mlir/IR/OpImplementation.h" | ||||
| #include "mlir/IR/StandardTypes.h" | ||||
| 
 | ||||
| namespace onnf { | ||||
| 
 | ||||
| class KrnlDialectOperandParser { | ||||
|  public: | ||||
|   KrnlDialectOperandParser(mlir::OpAsmParser& parser) | ||||
|       : _parser(parser), _builder(parser.getBuilder()){}; | ||||
| 
 | ||||
|   // Parse an optional operand.
 | ||||
|   mlir::ParseResult ParseOptionalOperand( | ||||
|       mlir::Type operand_type, mlir::Value*& operand); | ||||
| 
 | ||||
|   // Parse a required operand.
 | ||||
|   mlir::ParseResult ParseOperand( | ||||
|       mlir::Type operand_type, mlir::Value*& operand); | ||||
| 
 | ||||
|   // Do we have more operands to parse?
 | ||||
|   bool has_operand_left() { return !_operand_ref_queue.empty(); } | ||||
| 
 | ||||
|  private: | ||||
|   mlir::OpAsmParser& _parser; | ||||
| 
 | ||||
|   mlir::Builder& _builder; | ||||
| 
 | ||||
|   // A queue storing the parsed SSA id references.
 | ||||
|   std::queue<mlir::OpAsmParser::OperandType> _operand_ref_queue; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace onnf
 | ||||
|  | @ -16,6 +16,7 @@ | |||
| #include "mlir/Pass/Pass.h" | ||||
| #include "mlir/Transforms/DialectConversion.h" | ||||
| 
 | ||||
| #include "src/compiler/dialect/krnl/krnl_helper.hpp" | ||||
| #include "src/compiler/dialect/krnl/krnl_ops.hpp" | ||||
| #include "src/compiler/dialect/onnx/onnx_ops.hpp" | ||||
| 
 | ||||
|  | @ -43,13 +44,12 @@ static MemRefType convertTensorToMemRef(TensorType type) { | |||
| } | ||||
| 
 | ||||
| /// Insert an allocation and deallocation for the given MemRefType.
 | ||||
| static Value* insertAllocAndDealloc( | ||||
|     MemRefType type, Location loc, PatternRewriter& rewriter, | ||||
|     Value *oldMemRef = nullptr) { | ||||
| static Value* insertAllocAndDealloc(MemRefType type, Location loc, | ||||
|     PatternRewriter& rewriter, Value* oldMemRef = nullptr) { | ||||
|   // Put together alloc operands for any dynamic dimensions of the memref.
 | ||||
|   AllocOp alloc; | ||||
|   if (oldMemRef) { | ||||
|     SmallVector<Value *, 4> allocOperands; | ||||
|     SmallVector<Value*, 4> allocOperands; | ||||
|     auto memRefShape = type.getShape(); | ||||
|     for (int i = 0; i < memRefShape.size(); ++i) | ||||
|       if (memRefShape[i] < 0) | ||||
|  | @ -95,7 +95,7 @@ struct ONNXAddOpLowering : public ConversionPattern { | |||
|     // dimensions with the result at this pre-optimization phase.
 | ||||
|     // TODO: verify that dimensions match.
 | ||||
|     // TODO: can the dimension of the result differ after optimizations?
 | ||||
|     Value *alloc; | ||||
|     Value* alloc; | ||||
|     if (hasAllConstantDimensions(memRefType)) | ||||
|       alloc = insertAllocAndDealloc(memRefType, loc, rewriter); | ||||
|     else | ||||
|  | @ -122,33 +122,22 @@ struct ONNXAddOpLowering : public ConversionPattern { | |||
|     } | ||||
|     Block& optimizationBlock = optimizedLoopsOp.region().front(); | ||||
| 
 | ||||
|     KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); | ||||
|     // Iterate over the loop nest.
 | ||||
|     // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
 | ||||
|     // to KrnlIterateOp instead.
 | ||||
|     SmallVector<Value*, 8> operandBounds; | ||||
|     SmallVector<int64_t, 8> constBounds; | ||||
|     SmallVector<int, 16> boundTypes; | ||||
|     for (int i = 0; i < rank; ++i) { | ||||
|       if (memRefShape[i] < 0) { | ||||
|         // This is a dynamic value, hence use operands.
 | ||||
|         // Lower bound
 | ||||
|         constBounds.push_back(0); | ||||
|         boundTypes.push_back(0); | ||||
|         // Upper bound
 | ||||
|         operandBounds.push_back( | ||||
|         pack.pushConstantBound(0); | ||||
|         pack.pushOperandBound( | ||||
|             rewriter.create<DimOp>(loc, operands[0], i).getResult()); | ||||
|         boundTypes.push_back(1); | ||||
|       } else { | ||||
|         // Lower bound
 | ||||
|         constBounds.push_back(0); | ||||
|         boundTypes.push_back(0); | ||||
|         // Upper bound
 | ||||
|         constBounds.push_back(memRefShape[i]); | ||||
|         boundTypes.push_back(0); | ||||
|         pack.pushConstantBound(0); | ||||
|         pack.pushConstantBound(memRefShape[i]); | ||||
|       } | ||||
|     } | ||||
|     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, originalLoops, | ||||
|         optimizedLoops, operandBounds, constBounds, boundTypes); | ||||
| 
 | ||||
|     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack); | ||||
|     Block& iterationBlock = iterateOp.bodyRegion().front(); | ||||
| 
 | ||||
|     // Now perform the insertions into the body of the
 | ||||
|  | @ -169,14 +158,12 @@ struct ONNXAddOpLowering : public ConversionPattern { | |||
|     SmallVector<Value*, 4> loopIVs; | ||||
|     for (auto arg : iterationBlock.getArguments()) | ||||
|       loopIVs.push_back(arg); | ||||
|     auto loadedFirstVal = | ||||
|         rewriter.create<LoadOp>(loc, operands[0], loopIVs); | ||||
|     auto loadedSecondVal = | ||||
|         rewriter.create<LoadOp>(loc, operands[1], loopIVs); | ||||
|     auto loadedFirstVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs); | ||||
|     auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs); | ||||
| 
 | ||||
|     // TODO: Choose type of the Add for now use the Float Add.
 | ||||
|     auto addOpResult = rewriter.create<AddFOp>( | ||||
|         loc, loadedFirstVal, loadedSecondVal); | ||||
|     auto addOpResult = | ||||
|         rewriter.create<AddFOp>(loc, loadedFirstVal, loadedSecondVal); | ||||
| 
 | ||||
|     // Store result in the resulting array.
 | ||||
|     rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs); | ||||
|  | @ -209,8 +196,8 @@ struct TensorTypeConverter : public TypeConverter { | |||
|   /// inputs. Once unranked results can be handled gracefully this
 | ||||
|   /// override needs to be removed in favour of the original MLIR one.]
 | ||||
|   bool isSignatureLegal(FunctionType funcType) { | ||||
|     return llvm::all_of(funcType.getInputs(), | ||||
|         [this](Type type) { return isLegal(type); }); | ||||
|     return llvm::all_of( | ||||
|         funcType.getInputs(), [this](Type type) { return isLegal(type); }); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
|  | @ -272,8 +259,7 @@ void FrontendToKrnlLoweringPass::runOnModule() { | |||
|   // With the target and rewrite patterns defined, we can now attempt the
 | ||||
|   // conversion. The conversion will signal failure if any of our `illegal`
 | ||||
|   // operations were not converted successfully.
 | ||||
|   if (failed(applyPartialConversion( | ||||
|           module, target, patterns))) | ||||
|   if (failed(applyPartialConversion(module, target, patterns))) | ||||
|     signalPassFailure(); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -17,9 +17,10 @@ class Pass; | |||
| 
 | ||||
| std::unique_ptr<Pass> createShapeInferencePass(); | ||||
| 
 | ||||
| /// Pass for lowering frontend dialects to Krnl IR dialect.
 | ||||
| std::unique_ptr<mlir::Pass> createLowerToKrnlPass(); | ||||
| /// Add pass for lowering to Krnl IR.
 | ||||
| std::unique_ptr<Pass> createLowerToKrnlPass(); | ||||
| 
 | ||||
| // TODO: Add pass for lowering to LLVM IR.
 | ||||
| /// Pass for lowering frontend dialects to Krnl IR dialect.
 | ||||
| std::unique_ptr<Pass> createLowerKrnlPass(); | ||||
| 
 | ||||
| }  // end namespace mlir
 | ||||
|  |  | |||
|  | @ -75,7 +75,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { | |||
|     if (auto terminator_op = f.getBody().back().getTerminator()) { | ||||
|       auto results = terminator_op->getOperandTypes(); | ||||
|       f.setType(FunctionType::get(f.getType().getInputs(), | ||||
|                 std::vector<Type>(results.begin(), results.end()), f.getContext())); | ||||
|           std::vector<Type>(results.begin(), results.end()), f.getContext())); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|  |  | |||
|  | @ -6,8 +6,7 @@ target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT}) | |||
| 
 | ||||
| target_link_libraries(onnf-opt compiler ${MLIRLibs}) | ||||
| whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs}) | ||||
| whole_archive_link_onnf(onnf-opt onnf_lower_frontend) | ||||
| whole_archive_link_onnf(onnf-opt onnf_shape_inference) | ||||
| whole_archive_link_onnf(onnf-opt onnf_transform onnf_lower_frontend onnf_shape_inference) | ||||
| 
 | ||||
| # TODO: need to investigate how to whole-archive link compiler pass to onnf-opt. | ||||
| target_link_libraries(onnf-opt compiler) | ||||
|  |  | |||
|  | @ -0,0 +1,7 @@ | |||
| add_library(onnf_transform lower_krnl.cpp) | ||||
| 
 | ||||
| target_include_directories(onnf_transform | ||||
|                            PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} | ||||
|                                    ${ONNF_SRC_ROOT}) | ||||
| target_link_libraries(onnf_transform ${MLIRLibs}) | ||||
| add_dependencies(onnf_transform gen_krnl_ops) | ||||
|  | @ -0,0 +1,161 @@ | |||
| #include "mlir/Dialect/AffineOps/AffineOps.h" | ||||
| #include "mlir/Dialect/StandardOps/Ops.h" | ||||
| #include "mlir/Pass/Pass.h" | ||||
| #include "mlir/Transforms/DialectConversion.h" | ||||
| 
 | ||||
| #include "src/compiler/dialect/krnl/krnl_ops.hpp" | ||||
| #include "src/compiler/pass/passes.hpp" | ||||
| 
 | ||||
| using namespace mlir; | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Krnl to Affine Rewrite Patterns: KrnlIterate operation.
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> { | ||||
|   using OpRewritePattern<KrnlIterateOp>::OpRewritePattern; | ||||
| 
 | ||||
|   PatternMatchResult matchAndRewrite( | ||||
|       KrnlIterateOp iterateOp, PatternRewriter& rewriter) const override { | ||||
|     auto boundMapAttrs = | ||||
|         iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName()) | ||||
|             .getValue(); | ||||
|     auto operandItr = | ||||
|         iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops(); | ||||
|     SmallVector<AffineForOp, 4> nestedForOps; | ||||
|     for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) { | ||||
|       // Consume input loop operand, currently do not do anything with it.
 | ||||
|       operandItr++; | ||||
| 
 | ||||
|       // Organize operands into lower/upper bounds in affine.for ready formats.
 | ||||
|       SmallVector<Value*, 4> lbOperands, ubOperands; | ||||
|       AffineMap lbMap, ubMap; | ||||
|       for (int boundType = 0; boundType < 2; boundType++) { | ||||
|         auto& operands = boundType == 0 ? lbOperands : ubOperands; | ||||
|         auto& map = boundType == 0 ? lbMap : ubMap; | ||||
|         map = boundMapAttrs[boundIdx + boundType] | ||||
|                   .cast<AffineMapAttr>() | ||||
|                   .getValue(); | ||||
|         operands.insert( | ||||
|             operands.end(), operandItr, operandItr + map.getNumInputs()); | ||||
|         std::advance(operandItr, map.getNumInputs()); | ||||
|       } | ||||
| 
 | ||||
|       nestedForOps.emplace_back(rewriter.create<AffineForOp>( | ||||
|           iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap)); | ||||
|       rewriter.setInsertionPoint(nestedForOps.back().getBody(), | ||||
|           nestedForOps.back().getBody()->begin()); | ||||
|     } | ||||
| 
 | ||||
|     // Replace induction variable references from those introduced by a
 | ||||
|     // single krnl.iterate to those introduced by multiple affine.for
 | ||||
|     // operations.
 | ||||
|     for (size_t i = 0; i < nestedForOps.size() - 1; i++) { | ||||
|       auto iterateIV = iterateOp.bodyRegion().front().getArgument(0); | ||||
|       auto forIV = nestedForOps[i].getBody()->getArgument(0); | ||||
|       iterateIV->replaceAllUsesWith(forIV); | ||||
|       iterateOp.bodyRegion().front().eraseArgument(0); | ||||
|     } | ||||
| 
 | ||||
|     // Pop krnl.iterate body region block arguments, leave the last one
 | ||||
|     // for convenience (it'll be taken care of by region inlining).
 | ||||
|     while (iterateOp.bodyRegion().front().getNumArguments() > 1) | ||||
|       iterateOp.bodyRegion().front().eraseArgument(0); | ||||
| 
 | ||||
|     // Transfer krnl.iterate region to innermost for op.
 | ||||
|     auto innermostForOp = nestedForOps.back(); | ||||
|     innermostForOp.region().getBlocks().clear(); | ||||
|     rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(), | ||||
|         innermostForOp.region().end()); | ||||
| 
 | ||||
|     rewriter.eraseOp(iterateOp); | ||||
|     return matchSuccess(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Krnl to Affine Rewrite Patterns: KrnlTerminator operation.
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> { | ||||
|  public: | ||||
|   using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern; | ||||
| 
 | ||||
|   PatternMatchResult matchAndRewrite( | ||||
|       KrnlTerminatorOp op, PatternRewriter& rewriter) const override { | ||||
|     rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op); | ||||
|     return matchSuccess(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Krnl to Affine Rewrite Patterns: KrnlDefineLoops operation.
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> { | ||||
|  public: | ||||
|   using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern; | ||||
| 
 | ||||
|   PatternMatchResult matchAndRewrite( | ||||
|       KrnlDefineLoopsOp op, PatternRewriter& rewriter) const override { | ||||
|     rewriter.eraseOp(op); | ||||
|     return matchSuccess(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Krnl to Affine Rewrite Patterns: KrnlOptimizeLoops operation.
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> { | ||||
|  public: | ||||
|   using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern; | ||||
| 
 | ||||
|   PatternMatchResult matchAndRewrite( | ||||
|       KrnlOptimizeLoopsOp op, PatternRewriter& rewriter) const override { | ||||
|     rewriter.eraseOp(op); | ||||
|     return matchSuccess(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // KrnlToAffineLoweringPass
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| /// This is a partial lowering to affine loops of the krnl dialect operations.
 | ||||
| /// At this stage the dialect will contain standard operations as well like
 | ||||
| /// add and multiply, this pass will leave these operations intact.
 | ||||
| namespace { | ||||
| struct KrnlToAffineLoweringPass | ||||
|     : public FunctionPass<KrnlToAffineLoweringPass> { | ||||
|   void runOnFunction() final; | ||||
| }; | ||||
| }  // end anonymous namespace.
 | ||||
| 
 | ||||
| void KrnlToAffineLoweringPass::runOnFunction() { | ||||
|   auto function = getFunction(); | ||||
| 
 | ||||
|   ConversionTarget target(getContext()); | ||||
| 
 | ||||
|   target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>(); | ||||
|   // We expect IR to be free of Krnl Dialect Ops.
 | ||||
|   target.addIllegalDialect<KrnlOpsDialect>(); | ||||
| 
 | ||||
|   OwningRewritePatternList patterns; | ||||
|   patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering, | ||||
|       KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext()); | ||||
| 
 | ||||
|   if (failed(applyPartialConversion(getFunction(), target, patterns))) | ||||
|     signalPassFailure(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| std::unique_ptr<Pass> mlir::createLowerKrnlPass() { | ||||
|   return std::make_unique<KrnlToAffineLoweringPass>(); | ||||
| } | ||||
| 
 | ||||
| static PassRegistration<KrnlToAffineLoweringPass> pass( | ||||
|     "lower-krnl", "Lower Krnl dialect."); | ||||
							
								
								
									
										16
									
								
								src/main.cpp
								
								
								
								
							
							
						
						
									
										16
									
								
								src/main.cpp
								
								
								
								
							|  | @ -28,6 +28,7 @@ | |||
| 
 | ||||
| #include <boost/program_options.hpp> | ||||
| 
 | ||||
| #include "llvm/Bitcode/BitcodeWriter.h" | ||||
| #include "llvm/Support/FileUtilities.h" | ||||
| #include "llvm/Support/Regex.h" | ||||
| #include "llvm/Support/SourceMgr.h" | ||||
|  | @ -38,6 +39,8 @@ | |||
| #include "src/compiler/pass/passes.hpp" | ||||
| 
 | ||||
| #include "mlir/Analysis/Verifier.h" | ||||
| #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" | ||||
| #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" | ||||
| #include "mlir/ExecutionEngine/ExecutionEngine.h" | ||||
| #include "mlir/ExecutionEngine/OptUtils.h" | ||||
| #include "mlir/IR/MLIRContext.h" | ||||
|  | @ -125,7 +128,20 @@ int main(int ac, char* av[]) { | |||
|   pm.addPass(mlir::createShapeInferencePass()); | ||||
|   pm.addPass(mlir::createCanonicalizerPass()); | ||||
|   pm.addPass(mlir::createLowerToKrnlPass()); | ||||
|   pm.addPass(mlir::createLowerKrnlPass()); | ||||
|   pm.addPass(mlir::createLowerAffinePass()); | ||||
|   pm.addPass(mlir::createLowerToCFGPass()); | ||||
|   pm.addPass(mlir::createLowerToLLVMPass()); | ||||
|   pm.addPass(mlir::createCanonicalizerPass()); | ||||
|   pm.run(*module); | ||||
| 
 | ||||
|   // Write LLVM bitcode to disk.
 | ||||
|   std::error_code EC; | ||||
|   llvm::raw_fd_ostream moduleBitcodeStream( | ||||
|       "model.bc", EC, llvm::sys::fs::F_None); | ||||
|   llvm::WriteBitcodeToFile( | ||||
|       *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); | ||||
|   moduleBitcodeStream.flush(); | ||||
| 
 | ||||
|   return 0; | ||||
| } | ||||
|  |  | |||
|  | @ -0,0 +1,75 @@ | |||
| // RUN: onnf-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s | ||||
| // RUN: onnf-opt %s | FileCheck %s | ||||
| 
 | ||||
| // GENERIC-DAG: #{{.*}} = () -> (0) | ||||
| // GENERIC-DAG: #{{.*}} = () -> (10) | ||||
| // GENERIC-DAG: #{{.*}} = () -> (1) | ||||
| // GENERIC-DAG: #{{.*}} = () -> (11) | ||||
| // GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 - d1) | ||||
| // GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 + d1) | ||||
| 
 | ||||
| func @simple_iterate(%N : index) { | ||||
|   %ii, %ij, %ik = krnl.define_loops 3 | ||||
|   %oi, %oj, %ok = krnl.optimize_loops  { | ||||
|     krnl.return_loops %ii, %ij, %ik | ||||
|   } : () -> (!krnl.loop, !krnl.loop, !krnl.loop) | ||||
| 
 | ||||
|   // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||
|   // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): | ||||
|   // GENERIC-NEXT: "krnl.terminate"() : () -> () | ||||
|   // GENERIC-NEXT: bounds = [#{{.*}}, #{{.*}}, #{{.*}}, #{{.*}}] | ||||
| 
 | ||||
|   // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 1 to 11) { | ||||
|   krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 1 to 11) { | ||||
| 
 | ||||
|   } | ||||
| 
 | ||||
|   // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||
|   // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): | ||||
|   // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) { | ||||
|   krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 0 to 10) { | ||||
|     // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}) ( { | ||||
|     // GENERIC-NEXT: ^bb0(%{{.*}}: index): | ||||
|     // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10) { | ||||
|     krnl.iterate(%ok) with (%ik -> %k = 0 to 10) { | ||||
| 
 | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||
|   // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): | ||||
|   // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to %{{.*}}, %{{.*}} -> %{{.*}} = 0 to 10) { | ||||
|   krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to %N, %ij -> %j = 0 to 10) { | ||||
| 
 | ||||
|   } | ||||
| 
 | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| func @affine_map_bound(%N : index) { | ||||
|   %ii, %ij, %ik = krnl.define_loops 3 | ||||
|   %oi, %oj, %ok = krnl.optimize_loops  { | ||||
|     krnl.return_loops %ii, %ij, %ik | ||||
|   } : () -> (!krnl.loop, !krnl.loop, !krnl.loop) | ||||
| 
 | ||||
|   // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||
|   // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): | ||||
|   // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) { | ||||
|   krnl.iterate(%oi, %oj) with (%ii -> %i = ()->(0)() to ()->(10)(), %ij -> %j = 0 to 10) { | ||||
|     // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||
|     // GENERIC-NEXT: ^bb0(%{{.*}}: index): | ||||
|     // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = #{{.*}}(%{{.*}}, %{{.*}}) to #{{.*}}(%{{.*}}, %{{.*}})) { | ||||
|     krnl.iterate(%ok) with (%ik -> %k = (d0, d1)->(d0 - d1)(%i, %j) to (d0, d1)->(d0 + d1)(%i, %j)) { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||
|     // GENERIC-NEXT: ^bb0(%{{.*}}: index): | ||||
|     // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = max #map{{.*}}(%{{.*}}, %{{.*}}) to min #map{{.*}}(%{{.*}}, %{{.*}})[%{{.*}}]) { | ||||
|     krnl.iterate(%ok) with (%ik -> %k = max (d0, d1)->(d0 - d1, 0)(%i, %j) to min (d0, d1)[s0]->(d0 + d1, s0)(%i, %j)[%N]) { | ||||
| 
 | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   return | ||||
| } | ||||
		Loading…
	
		Reference in New Issue