[MLIR] Refactor Krnl Dialect and Krnl Dialect Lowering (#375)
* Store bounds as affine map attributes & check in test cases with generic printer * Upgrading MLIR MLIR is outdated on Buildbot, rebuilding a newer version. * work with new version of mlir * check-in parser tests * custom printer * nit * bug fix * enable custom asm printer test * enable custom asm printer test * more consistent variable naming * test max/min * variable naming scheme change to MLIR style * can lower krnl to llvm * kernel -> llvm * comments * bug fix * try fixing ci * fix ci * deactivate model test * fix lit test * nit * fix z buildbot
This commit is contained in:
		
							parent
							
								
									652ce4b7d4
								
							
						
					
					
						commit
						b2a1103915
					
				|  | @ -18,6 +18,8 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) | ||||||
| #TODO(eventually enable the following) | #TODO(eventually enable the following) | ||||||
| #set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) | #set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) | ||||||
| 
 | 
 | ||||||
|  | include(MLIR.cmake) | ||||||
|  | 
 | ||||||
| add_subdirectory(third_party/onnx) | add_subdirectory(third_party/onnx) | ||||||
| add_subdirectory(third_party/benchmark) | add_subdirectory(third_party/benchmark) | ||||||
| add_subdirectory(third_party/pybind11) | add_subdirectory(third_party/pybind11) | ||||||
|  | @ -41,7 +43,6 @@ if(Boost_FOUND) | ||||||
|   include_directories(${Boost_INCLUDE_DIRS}) |   include_directories(${Boost_INCLUDE_DIRS}) | ||||||
| endif() | endif() | ||||||
| 
 | 
 | ||||||
| include(MLIR.cmake) |  | ||||||
| add_subdirectory(src/builder) | add_subdirectory(src/builder) | ||||||
| add_subdirectory(src/compiler) | add_subdirectory(src/compiler) | ||||||
| add_subdirectory(src) | add_subdirectory(src) | ||||||
|  |  | ||||||
							
								
								
									
										92
									
								
								MLIR.cmake
								
								
								
								
							
							
						
						
									
										92
									
								
								MLIR.cmake
								
								
								
								
							|  | @ -5,7 +5,7 @@ if(DEFINED ENV{LLVM_SRC}) | ||||||
|     message(STATUS "LLVM_SRC " ${LLVM_SRC}) |     message(STATUS "LLVM_SRC " ${LLVM_SRC}) | ||||||
|   else() |   else() | ||||||
|     message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: " |     message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: " | ||||||
|                         ${LLVM_SRC}) |             ${LLVM_SRC}) | ||||||
|   endif() |   endif() | ||||||
| else() | else() | ||||||
|   message(FATAL_ERROR "env variable LLVM_SRC not set") |   message(FATAL_ERROR "env variable LLVM_SRC not set") | ||||||
|  | @ -18,7 +18,7 @@ if(DEFINED ENV{LLVM_BUILD}) | ||||||
|     message(STATUS "LLVM_BUILD " ${LLVM_BUILD}) |     message(STATUS "LLVM_BUILD " ${LLVM_BUILD}) | ||||||
|   else() |   else() | ||||||
|     message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: " |     message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: " | ||||||
|                         ${LLVM_BUILD}) |             ${LLVM_BUILD}) | ||||||
|   endif() |   endif() | ||||||
| else() | else() | ||||||
|   message(FATAL_ERROR "env variable LLVM_BUILD not set") |   message(FATAL_ERROR "env variable LLVM_BUILD not set") | ||||||
|  | @ -39,9 +39,9 @@ set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir) | ||||||
| set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir) | set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir) | ||||||
| 
 | 
 | ||||||
| set( | set( | ||||||
|   MLIR_INCLUDE_PATHS |         MLIR_INCLUDE_PATHS | ||||||
|   ${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH} |         ${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH} | ||||||
|   ) | ) | ||||||
| include_directories(${MLIR_INCLUDE_PATHS}) | include_directories(${MLIR_INCLUDE_PATHS}) | ||||||
| 
 | 
 | ||||||
| # Threading libraries required due to parallel pass execution. | # Threading libraries required due to parallel pass execution. | ||||||
|  | @ -49,9 +49,9 @@ find_package(Threads REQUIRED) | ||||||
| 
 | 
 | ||||||
| function(find_mlir_lib lib) | function(find_mlir_lib lib) | ||||||
|   find_library(${lib} |   find_library(${lib} | ||||||
|                NAMES ${lib} |           NAMES ${lib} | ||||||
|                PATHS ${LLVM_PROJECT_LIB} |           PATHS ${LLVM_PROJECT_LIB} | ||||||
|                NO_DEFAULT_PATH) |           NO_DEFAULT_PATH) | ||||||
| endfunction(find_mlir_lib) | endfunction(find_mlir_lib) | ||||||
| 
 | 
 | ||||||
| find_mlir_lib(MLIRAffineOps) | find_mlir_lib(MLIRAffineOps) | ||||||
|  | @ -70,6 +70,10 @@ find_mlir_lib(MLIRTransforms) | ||||||
| find_mlir_lib(MLIRTransformUtils) | find_mlir_lib(MLIRTransformUtils) | ||||||
| find_mlir_lib(MLIRSupport) | find_mlir_lib(MLIRSupport) | ||||||
| find_mlir_lib(MLIROptMain) | find_mlir_lib(MLIROptMain) | ||||||
|  | find_mlir_lib(MLIRTargetLLVMIRModuleTranslation) | ||||||
|  | find_mlir_lib(MLIRTargetLLVMIR) | ||||||
|  | find_mlir_lib(MLIRTransformUtils) | ||||||
|  | find_mlir_lib(MLIRTranslation) | ||||||
| find_mlir_lib(MLIRVectorOps) | find_mlir_lib(MLIRVectorOps) | ||||||
| 
 | 
 | ||||||
| find_mlir_lib(LLVMCore) | find_mlir_lib(LLVMCore) | ||||||
|  | @ -80,46 +84,52 @@ find_mlir_lib(LLVMRemarks) | ||||||
| find_mlir_lib(LLVMIRReader) | find_mlir_lib(LLVMIRReader) | ||||||
| find_mlir_lib(LLVMTransformUtils) | find_mlir_lib(LLVMTransformUtils) | ||||||
| find_mlir_lib(LLVMBitstreamReader) | find_mlir_lib(LLVMBitstreamReader) | ||||||
|  | find_mlir_lib(LLVMAnalysis) | ||||||
|  | find_mlir_lib(LLVMBitWriter) | ||||||
|  | find_mlir_lib(LLVMBitReader) | ||||||
|  | find_mlir_lib(LLVMMC) | ||||||
|  | find_mlir_lib(LLVMMCParser) | ||||||
|  | find_mlir_lib(LLVMObject) | ||||||
|  | find_mlir_lib(LLVMProfileData) | ||||||
|  | find_mlir_lib(LLVMDemangle) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| set(MLIRLibsOnce | set(MLIRLibsOnce | ||||||
|  |         LLVMAnalysis | ||||||
|  |         LLVMAsmParser | ||||||
|  |         LLVMBinaryFormat | ||||||
|  |         LLVMBitReader | ||||||
|  |         LLVMBitstreamReader | ||||||
|  |         LLVMBitWriter | ||||||
|  |         LLVMCore | ||||||
|  |         LLVMIRReader | ||||||
|  |         LLVMMC | ||||||
|  |         LLVMMCParser | ||||||
|  |         LLVMObject | ||||||
|  |         LLVMRemarks | ||||||
|  |         LLVMSupport | ||||||
|  |         LLVMTransformUtils | ||||||
|  |         LLVMProfileData | ||||||
|  |         LLVMDemangle | ||||||
|         MLIRAffineOps |         MLIRAffineOps | ||||||
|         MLIRAffineToStandard |         MLIRAffineToStandard | ||||||
|         MLIRAnalysis |         MLIRAnalysis | ||||||
|         MLIRExecutionEngine |         MLIRExecutionEngine | ||||||
|         MLIRIR |         MLIRIR | ||||||
|         MLIRLLVMIR |         MLIRLLVMIR | ||||||
|  |         MLIRLoopOps | ||||||
|         MLIRLoopToStandard |         MLIRLoopToStandard | ||||||
|  |         MLIROptMain | ||||||
|         MLIRParser |         MLIRParser | ||||||
|         MLIRPass |         MLIRPass | ||||||
|         MLIRStandardOps |         MLIRStandardOps | ||||||
|         MLIRStandardToLLVM |         MLIRStandardToLLVM | ||||||
|  |         MLIRSupport | ||||||
|         MLIRTargetLLVMIR |         MLIRTargetLLVMIR | ||||||
|         MLIRTransforms |         MLIRTargetLLVMIRModuleTranslation | ||||||
|         MLIRAffineOps |  | ||||||
|         MLIRAffineToStandard |  | ||||||
|         MLIRAnalysis |  | ||||||
|         MLIRExecutionEngine |  | ||||||
|         MLIRIR |  | ||||||
|         MLIRLLVMIR |  | ||||||
|         MLIRLoopToStandard |  | ||||||
|         MLIRParser |  | ||||||
|         MLIRPass |  | ||||||
|         MLIRStandardOps |  | ||||||
|         MLIRStandardToLLVM |  | ||||||
|         MLIRTargetLLVMIR |  | ||||||
|         MLIRTransforms |         MLIRTransforms | ||||||
|         MLIRTransformUtils |         MLIRTransformUtils | ||||||
|         MLIRLoopOps |         MLIRTranslation) | ||||||
|         MLIRSupport |  | ||||||
|         MLIROptMain |  | ||||||
|         LLVMCore |  | ||||||
|         LLVMSupport |  | ||||||
|         LLVMAsmParser |  | ||||||
|         LLVMIRReader |  | ||||||
|         LLVMTransformUtils |  | ||||||
|         LLVMBinaryFormat |  | ||||||
|         LLVMRemarks |  | ||||||
|         LLVMBitstreamReader) |  | ||||||
| 
 | 
 | ||||||
| set(MLIRLibs | set(MLIRLibs | ||||||
|         ${MLIRLibsOnce} |         ${MLIRLibsOnce} | ||||||
|  | @ -142,7 +152,7 @@ function(whole_archive_link target lib_dir) | ||||||
|     set(link_flags "${link_flags} -L${lib_dir} ") |     set(link_flags "${link_flags} -L${lib_dir} ") | ||||||
|     foreach(LIB ${ARGN}) |     foreach(LIB ${ARGN}) | ||||||
|       string(CONCAT link_flags ${link_flags} |       string(CONCAT link_flags ${link_flags} | ||||||
|                     "-Wl,-force_load ${lib_dir}/lib${LIB}.a ") |               "-Wl,-force_load ${lib_dir}/lib${LIB}.a ") | ||||||
|     endforeach(LIB) |     endforeach(LIB) | ||||||
|   elseif(MSVC) |   elseif(MSVC) | ||||||
|     foreach(LIB ${ARGN}) |     foreach(LIB ${ARGN}) | ||||||
|  | @ -170,20 +180,20 @@ function(whole_archive_link_onnf target) | ||||||
| endfunction(whole_archive_link_onnf) | endfunction(whole_archive_link_onnf) | ||||||
| 
 | 
 | ||||||
| set(LLVM_CMAKE_DIR | set(LLVM_CMAKE_DIR | ||||||
|     "${LLVM_BUILD}/lib/cmake/llvm" |         "${LLVM_BUILD}/lib/cmake/llvm" | ||||||
|     CACHE PATH "Path to LLVM cmake modules") |         CACHE PATH "Path to LLVM cmake modules") | ||||||
| list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") | list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") | ||||||
| include(AddLLVM) | include(AddLLVM) | ||||||
| include(TableGen) | include(TableGen) | ||||||
| 
 | 
 | ||||||
| function(onnf_tablegen ofn) | function(onnf_tablegen ofn) | ||||||
|   tablegen(MLIR |   tablegen(MLIR | ||||||
|            ${ARGV} |           ${ARGV} | ||||||
|            "-I${MLIR_SRC_INCLUDE_PATH}" |           "-I${MLIR_SRC_INCLUDE_PATH}" | ||||||
|            "-I${MLIR_BIN_INCLUDE_PATH}") |           "-I${MLIR_BIN_INCLUDE_PATH}") | ||||||
|   set(TABLEGEN_OUTPUT |   set(TABLEGEN_OUTPUT | ||||||
|       ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn} |           ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn} | ||||||
|       PARENT_SCOPE) |           PARENT_SCOPE) | ||||||
| endfunction() | endfunction() | ||||||
| 
 | 
 | ||||||
| # Import the pre-built mlir TableGen as an imported exetuable. It is required by | # Import the pre-built mlir TableGen as an imported exetuable. It is required by | ||||||
|  | @ -191,5 +201,5 @@ endfunction() | ||||||
| # table gen utility itself can be detected and cause re-compilation of .td file. | # table gen utility itself can be detected and cause re-compilation of .td file. | ||||||
| add_executable(mlir-tblgen IMPORTED) | add_executable(mlir-tblgen IMPORTED) | ||||||
| set_property(TARGET mlir-tblgen | set_property(TARGET mlir-tblgen | ||||||
|              PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen) |         PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen) | ||||||
| set(MLIR_TABLEGEN_EXE mlir-tblgen) | set(MLIR_TABLEGEN_EXE mlir-tblgen) | ||||||
|  |  | ||||||
|  | @ -2,6 +2,7 @@ add_executable(onnf main.cpp) | ||||||
| 
 | 
 | ||||||
| target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES}) | target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES}) | ||||||
| whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs}) | whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs}) | ||||||
|  | whole_archive_link_onnf(onnf onnf_transform) | ||||||
| 
 | 
 | ||||||
| target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR}) | target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR}) | ||||||
| target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR}) | target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR}) | ||||||
|  |  | ||||||
|  | @ -6,8 +6,8 @@ add_library( | ||||||
|         dialect/krnl/krnl_types.hpp |         dialect/krnl/krnl_types.hpp | ||||||
|         dialect/onnx/onnx_ops.cpp |         dialect/onnx/onnx_ops.cpp | ||||||
|         dialect/onnx/onnx_ops.hpp |         dialect/onnx/onnx_ops.hpp | ||||||
|         dialect/krnl/parser_helper.cpp |         dialect/krnl/krnl_helper.cpp | ||||||
|         dialect/krnl/parser_helper.hpp |         dialect/krnl/krnl_helper.hpp | ||||||
|         pass/shape_inference_pass.cpp |         pass/shape_inference_pass.cpp | ||||||
|         pass/shape_inference_interface.hpp |         pass/shape_inference_interface.hpp | ||||||
|         dialect/onnx/onnxop.inc |         dialect/onnx/onnxop.inc | ||||||
|  | @ -82,4 +82,5 @@ target_include_directories(onnf_lower_frontend | ||||||
| target_link_libraries(onnf_lower_frontend ${MLIRLibs}) | target_link_libraries(onnf_lower_frontend ${MLIRLibs}) | ||||||
| add_dependencies(onnf_lower_frontend gen_krnl_ops) | add_dependencies(onnf_lower_frontend gen_krnl_ops) | ||||||
| 
 | 
 | ||||||
|  | add_subdirectory(transform) | ||||||
| add_subdirectory(tool) | add_subdirectory(tool) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,139 @@ | ||||||
|  | #include "mlir/Dialect/AffineOps/AffineOps.h" | ||||||
|  | #include "mlir/IR/AffineExpr.h" | ||||||
|  | 
 | ||||||
|  | #include "src/compiler/dialect/krnl/krnl_ops.hpp" | ||||||
|  | 
 | ||||||
|  | #include "krnl_helper.hpp" | ||||||
|  | 
 | ||||||
|  | namespace onnf { | ||||||
|  | 
 | ||||||
|  | using namespace mlir; | ||||||
|  | 
 | ||||||
|  | ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ||||||
|  |     const Type& operandType, Value*& operand) { | ||||||
|  |   // If operand queue is empty, parse more operands and cache them.
 | ||||||
|  |   if (_operandRefQueue.empty()) { | ||||||
|  |     // Parse operand types:
 | ||||||
|  |     llvm::SmallVector<OpAsmParser::OperandType, 2> operand_refs; | ||||||
|  |     _parser.parseOperandList(operand_refs); | ||||||
|  | 
 | ||||||
|  |     // Record operands:
 | ||||||
|  |     for (auto& operand_ref : operand_refs) | ||||||
|  |       _operandRefQueue.emplace(operand_ref); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // If we parsed some operand reference(s), resolve the ref to an operand:
 | ||||||
|  |   if (!_operandRefQueue.empty()) { | ||||||
|  |     auto operand_ref = _operandRefQueue.front(); | ||||||
|  |     _operandRefQueue.pop(); | ||||||
|  | 
 | ||||||
|  |     llvm::SmallVector<Value*, 1> operands; | ||||||
|  |     _parser.resolveOperand(operand_ref, operandType, operands); | ||||||
|  |     operand = operands.front(); | ||||||
|  |     return success(); | ||||||
|  |   } else { | ||||||
|  |     operand = nullptr; | ||||||
|  |     return failure(); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ||||||
|  |     const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) { | ||||||
|  |   Value* operand = nullptr; | ||||||
|  |   if (ParseOptionalOperand(operandType, operand)) | ||||||
|  |     return failure(); | ||||||
|  | 
 | ||||||
|  |   operandList.emplace_back(operand); | ||||||
|  |   return success(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ParseResult KrnlDialectOperandParser::ParseOperand( | ||||||
|  |     const Type& operandType, Value*& operand) { | ||||||
|  |   if (ParseOptionalOperand(operandType, operand)) | ||||||
|  |     return _parser.emitError( | ||||||
|  |         _parser.getCurrentLocation(), "Expecting an operand."); | ||||||
|  |   return success(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ParseResult KrnlDialectOperandParser::ParseOperand( | ||||||
|  |     const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) { | ||||||
|  |   if (ParseOptionalOperand(operandType, operandList)) | ||||||
|  |     return _parser.emitError( | ||||||
|  |         _parser.getCurrentLocation(), "Expecting an operand."); | ||||||
|  | 
 | ||||||
|  |   return success(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims, | ||||||
|  |     unsigned numSymbols, OpAsmPrinter& p) { | ||||||
|  |   p << '('; | ||||||
|  |   p.printOperands(begin, begin + numDims); | ||||||
|  |   p << ')'; | ||||||
|  | 
 | ||||||
|  |   if (numSymbols) { | ||||||
|  |     p << '['; | ||||||
|  |     p.printOperands(begin + numDims, begin + numDims + numSymbols); | ||||||
|  |     p << ']'; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   begin = std::next(begin, numDims + numSymbols); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void printBound(AffineMapAttr boundMap, | ||||||
|  |     Operation::operand_iterator& boundOperandsBeg, const char* prefix, | ||||||
|  |     OpAsmPrinter& p) { | ||||||
|  |   AffineMap map = boundMap.getValue(); | ||||||
|  | 
 | ||||||
|  |   // Check if this bound should be printed using custom assembly form.
 | ||||||
|  |   // The decision to restrict printing custom assembly form to trivial cases
 | ||||||
|  |   // comes from the will to roundtrip MLIR binary -> text -> binary in a
 | ||||||
|  |   // lossless way.
 | ||||||
|  |   // Therefore, custom assembly form parsing and printing is only supported for
 | ||||||
|  |   // zero-operand constant maps and single symbol operand identity maps.
 | ||||||
|  |   if (map.getNumResults() == 1) { | ||||||
|  |     AffineExpr expr = map.getResult(0); | ||||||
|  | 
 | ||||||
|  |     // Print constant bound.
 | ||||||
|  |     if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { | ||||||
|  |       if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { | ||||||
|  |         p << constExpr.getValue(); | ||||||
|  |         return; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Print bound that consists of a single SSA symbol if the map is over a
 | ||||||
|  |     // single symbol.
 | ||||||
|  |     if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { | ||||||
|  |       if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) { | ||||||
|  |         p.printOperand(*(boundOperandsBeg++)); | ||||||
|  |         return; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } else { | ||||||
|  |     // Map has multiple results. Print 'min' or 'max' prefix.
 | ||||||
|  |     p << prefix << ' '; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Print the map and its operands.
 | ||||||
|  |   p << boundMap; | ||||||
|  |   printDimAndSymbolList( | ||||||
|  |       boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p); | ||||||
|  | } | ||||||
|  | }  // namespace onnf
 | ||||||
|  | 
 | ||||||
|  | namespace mlir { | ||||||
|  | void KrnlIterateOperandPack::pushConstantBound(int64_t bound) { | ||||||
|  |   if (boundMaps.size() % 2 == 0) | ||||||
|  |     _operands.emplace_back(inputLoops[boundMaps.size() / 2]); | ||||||
|  |   AffineMap map = builder.getConstantAffineMap(bound); | ||||||
|  |   boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) { | ||||||
|  |   if (boundMaps.size() % 2 == 0) | ||||||
|  |     _operands.emplace_back(inputLoops[boundMaps.size() / 2]); | ||||||
|  |   AffineMap map = builder.getSymbolIdentityMap(); | ||||||
|  |   boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||||
|  |   _operands.emplace_back(operand); | ||||||
|  | } | ||||||
|  | }  // namespace mlir
 | ||||||
|  | @ -0,0 +1,102 @@ | ||||||
|  | #pragma once | ||||||
|  | 
 | ||||||
|  | #include <queue> | ||||||
|  | 
 | ||||||
|  | #include "mlir/IR/Attributes.h" | ||||||
|  | #include "mlir/IR/Builders.h" | ||||||
|  | #include "mlir/IR/Dialect.h" | ||||||
|  | #include "mlir/IR/OpDefinition.h" | ||||||
|  | #include "mlir/IR/OpImplementation.h" | ||||||
|  | #include "mlir/IR/StandardTypes.h" | ||||||
|  | 
 | ||||||
|  | namespace onnf { | ||||||
|  | 
 | ||||||
|  | class KrnlDialectOperandParser { | ||||||
|  |  public: | ||||||
|  |   explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser) | ||||||
|  |       : _parser(parser), _builder(parser.getBuilder()){}; | ||||||
|  | 
 | ||||||
|  |   // Parse an optional operand.
 | ||||||
|  |   mlir::ParseResult ParseOptionalOperand( | ||||||
|  |       const mlir::Type& operandType, mlir::Value*& operand); | ||||||
|  | 
 | ||||||
|  |   // Parse an optional operand and push it to an operand list.
 | ||||||
|  |   mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType, | ||||||
|  |       llvm::SmallVectorImpl<mlir::Value*>& operandList); | ||||||
|  | 
 | ||||||
|  |   // Parse a required operand.
 | ||||||
|  |   mlir::ParseResult ParseOperand( | ||||||
|  |       const mlir::Type& operandType, mlir::Value*& operand); | ||||||
|  | 
 | ||||||
|  |   // Parse a required operand and push it to an operand list.
 | ||||||
|  |   mlir::ParseResult ParseOperand(const mlir::Type& operandType, | ||||||
|  |       llvm::SmallVectorImpl<mlir::Value*>& operandList); | ||||||
|  | 
 | ||||||
|  |   // Do we have more operands to parse?
 | ||||||
|  |   bool hasOperandLeft() { return !_operandRefQueue.empty(); } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   mlir::OpAsmParser& _parser; | ||||||
|  | 
 | ||||||
|  |   mlir::Builder& _builder; | ||||||
|  | 
 | ||||||
|  |   // A queue storing the parsed SSA id references.
 | ||||||
|  |   std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | // Adapted from:
 | ||||||
|  | // https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
 | ||||||
|  | // Main difference is that it advances the iterator `begin` as it consumes
 | ||||||
|  | // dimension and symbol operands.
 | ||||||
|  | void printDimAndSymbolList(mlir::Operation::operand_iterator& begin, | ||||||
|  |     unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p); | ||||||
|  | 
 | ||||||
|  | // Adapted from:
 | ||||||
|  | // https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
 | ||||||
|  | // Main difference is that it advances the iterator `boundOperandsBeg` as it
 | ||||||
|  | // prints bound.
 | ||||||
|  | void printBound(mlir::AffineMapAttr boundMap, | ||||||
|  |     mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix, | ||||||
|  |     mlir::OpAsmPrinter& p); | ||||||
|  | }  // namespace onnf
 | ||||||
|  | 
 | ||||||
|  | namespace mlir { | ||||||
|  | 
 | ||||||
|  | struct KrnlIterateOperandPack { | ||||||
|  |   KrnlIterateOperandPack(mlir::Builder& builder, | ||||||
|  |       llvm::ArrayRef<mlir::Value*> inputLoops, | ||||||
|  |       llvm::ArrayRef<mlir::Value*> optimizedLoops) | ||||||
|  |       : builder(builder), | ||||||
|  |         inputLoops(inputLoops), | ||||||
|  |         optimizedLoops(optimizedLoops) { | ||||||
|  |     _operands.insert( | ||||||
|  |         _operands.end(), optimizedLoops.begin(), optimizedLoops.end()); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   void pushConstantBound(int64_t bound); | ||||||
|  | 
 | ||||||
|  |   void pushOperandBound(mlir::Value* operand); | ||||||
|  | 
 | ||||||
|  |   llvm::SmallVector<mlir::Value*, 8> getOperands() const { return _operands; } | ||||||
|  | 
 | ||||||
|  |   mlir::ArrayAttr getAttributes() const { | ||||||
|  |     return builder.getArrayAttr(boundMaps); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   size_t getNumOptimizedLoops() const { return optimizedLoops.size(); } | ||||||
|  | 
 | ||||||
|  |   size_t getNumInputLoops() const { return inputLoops.size(); } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   int _boundIdx = 0; | ||||||
|  | 
 | ||||||
|  |   llvm::SmallVector<mlir::Value*, 8> _operands; | ||||||
|  | 
 | ||||||
|  |   llvm::SmallVector<mlir::Attribute, 8> boundMaps; | ||||||
|  | 
 | ||||||
|  |   llvm::ArrayRef<mlir::Value*> inputLoops, optimizedLoops; | ||||||
|  | 
 | ||||||
|  |   mlir::Builder& builder; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace mlir
 | ||||||
|  | @ -9,10 +9,10 @@ | ||||||
| #include <iostream> | #include <iostream> | ||||||
| #include <queue> | #include <queue> | ||||||
| 
 | 
 | ||||||
| #include "src/compiler/dialect/krnl/parser_helper.hpp" |  | ||||||
| 
 |  | ||||||
| #include "llvm/ADT/SetVector.h" | #include "llvm/ADT/SetVector.h" | ||||||
| #include "llvm/ADT/SmallBitVector.h" | #include "llvm/ADT/SmallBitVector.h" | ||||||
|  | #include "mlir/Dialect/AffineOps/AffineOps.h" | ||||||
|  | #include "mlir/Dialect/StandardOps/Ops.h" | ||||||
| #include "mlir/IR/Block.h" | #include "mlir/IR/Block.h" | ||||||
| #include "mlir/IR/Builders.h" | #include "mlir/IR/Builders.h" | ||||||
| #include "mlir/IR/Function.h" | #include "mlir/IR/Function.h" | ||||||
|  | @ -24,6 +24,8 @@ | ||||||
| #include "mlir/IR/PatternMatch.h" | #include "mlir/IR/PatternMatch.h" | ||||||
| #include "mlir/Transforms/DialectConversion.h" | #include "mlir/Transforms/DialectConversion.h" | ||||||
| 
 | 
 | ||||||
|  | #include "src/compiler/dialect/krnl/krnl_helper.hpp" | ||||||
|  | 
 | ||||||
| #include "krnl_ops.hpp" | #include "krnl_ops.hpp" | ||||||
| 
 | 
 | ||||||
| using namespace mlir; | using namespace mlir; | ||||||
|  | @ -52,26 +54,25 @@ void KrnlDefineLoopsOp::build( | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) { | void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) { | ||||||
|   auto num_loop_attr = op.getAttrOfType<IntegerAttr>(op.getNumLoopsAttrName()); |   auto numLoopAttr = | ||||||
|   p << "krnl.define_loops " << num_loop_attr.getValue().getSExtValue(); |       op.getAttrOfType<IntegerAttr>(KrnlDefineLoopsOp::getNumLoopsAttrName()); | ||||||
|  |   p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ParseResult parseKrnlDefineLoopsOp( | ParseResult parseKrnlDefineLoopsOp( | ||||||
|     OpAsmParser& parser, OperationState& result) { |     OpAsmParser& parser, OperationState& result) { | ||||||
|   // Parse the attribute indicating number of loops defined.
 |   // Parse the attribute indicating number of loops defined.
 | ||||||
|   IntegerAttr num_loops; |   IntegerAttr numLoops; | ||||||
|   auto& builder = parser.getBuilder(); |   auto& builder = parser.getBuilder(); | ||||||
|   auto int32_type = builder.getIntegerType(64); |   auto intType = builder.getIntegerType(64); | ||||||
|   if (parser.parseAttribute(num_loops, int32_type, |   if (parser.parseAttribute(numLoops, intType, | ||||||
|           KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes)) |           KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes)) | ||||||
|     return failure(); |     return failure(); | ||||||
| 
 | 
 | ||||||
|   auto loop_types = llvm::SmallVector<Type, 4>( |   auto loopTypes = llvm::SmallVector<Type, 4>( | ||||||
|       num_loops.getValue().getSExtValue(), LoopType::get(builder.getContext())); |       numLoops.getValue().getSExtValue(), LoopType::get(builder.getContext())); | ||||||
|   if (parser.addTypesToList(loop_types, result.types)) |   if (parser.addTypesToList(loopTypes, result.types)) | ||||||
|     return failure(); |     return failure(); | ||||||
| 
 |  | ||||||
|   return success(); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
|  | @ -142,39 +143,14 @@ ParseResult parseKrnlOptimizeLoopsOp( | ||||||
|  *   %i0 = 10 to N : %i1 = M to 20 |  *   %i0 = 10 to N : %i1 = M to 20 | ||||||
|  */ |  */ | ||||||
| void KrnlIterateOp::build(Builder* builder, OperationState& result, | void KrnlIterateOp::build(Builder* builder, OperationState& result, | ||||||
|     ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops, |     KrnlIterateOperandPack operandPack) { | ||||||
|     ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds, |  | ||||||
|     ArrayRef<int> bound_types) { |  | ||||||
|   // Record optimized loops and the number of such loops.
 |   // Record optimized loops and the number of such loops.
 | ||||||
|   result.addOperands(optimized_loops); |   result.addOperands(operandPack.getOperands()); | ||||||
|  |   result.addAttribute( | ||||||
|  |       KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes()); | ||||||
|  | 
 | ||||||
|   result.addAttribute(getNumOptimizedLoopsAttrName(), |   result.addAttribute(getNumOptimizedLoopsAttrName(), | ||||||
|       builder->getI64IntegerAttr(optimized_loops.size())); |       builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops())); | ||||||
| 
 |  | ||||||
|   // Record input loops and the number of such loops.
 |  | ||||||
|   result.addOperands(input_loops); |  | ||||||
|   result.addAttribute(getNumInputLoopsAttrName(), |  | ||||||
|       builder->getI64IntegerAttr(input_loops.size())); |  | ||||||
| 
 |  | ||||||
|   // Record bound either as attribute or from operand list.
 |  | ||||||
|   auto next_operand_bound = operand_bounds.begin(); |  | ||||||
|   auto next_const_bound = const_bounds.begin(); |  | ||||||
|   for (size_t i = 0; i < bound_types.size(); i++) { |  | ||||||
|     auto bound_type = bound_types[i]; |  | ||||||
|     if (bound_type == 0) { |  | ||||||
|       // Constant bound.
 |  | ||||||
|       result.addAttribute(getBoundAttrName(i / 2, i % 2), |  | ||||||
|           builder->getI64IntegerAttr(*next_const_bound)); |  | ||||||
|       next_const_bound = std::next(next_const_bound); |  | ||||||
|     } else { |  | ||||||
|       // Operand bound.
 |  | ||||||
|       result.addOperands(*next_operand_bound); |  | ||||||
|       next_operand_bound = std::next(next_operand_bound); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   // Record bound types as attribute:
 |  | ||||||
|   result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(), |  | ||||||
|       builder->getI32ArrayAttr(bound_types)); |  | ||||||
| 
 | 
 | ||||||
|   // Create a region and a block for the body. The arguments of the region are
 |   // Create a region and a block for the body. The arguments of the region are
 | ||||||
|   // the loop induction variables; there can be multiple induction variables
 |   // the loop induction variables; there can be multiple induction variables
 | ||||||
|  | @ -182,7 +158,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result, | ||||||
|   Region* bodyRegion = result.addRegion(); |   Region* bodyRegion = result.addRegion(); | ||||||
|   auto* body = new Block(); |   auto* body = new Block(); | ||||||
|   auto body_args = llvm::SmallVector<Type, 4>( |   auto body_args = llvm::SmallVector<Type, 4>( | ||||||
|       input_loops.size(), IndexType::get(builder->getContext())); |       operandPack.getNumInputLoops(), IndexType::get(builder->getContext())); | ||||||
|   body->addArguments(body_args); |   body->addArguments(body_args); | ||||||
|   bodyRegion->push_back(body); |   bodyRegion->push_back(body); | ||||||
| 
 | 
 | ||||||
|  | @ -192,57 +168,31 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result, | ||||||
| void print(OpAsmPrinter& p, KrnlIterateOp& op) { | void print(OpAsmPrinter& p, KrnlIterateOp& op) { | ||||||
|   p << "krnl.iterate("; |   p << "krnl.iterate("; | ||||||
|   // Print optimized loops:
 |   // Print optimized loops:
 | ||||||
|   auto num_optimized_loops = op.getNumOptimizedLoops(); |   auto numOptimizedLoops = op.getNumOptimizedLoops(); | ||||||
|   p.printOperands(op.operand_begin(), op.operand_begin() + num_optimized_loops); |   p.printOperands(op.operand_begin(), op.operand_begin() + numOptimizedLoops); | ||||||
|   p << ") with ("; |   p << ") with ("; | ||||||
| 
 | 
 | ||||||
|   // Set up iterator to input loops:
 |   auto inductionVars = op.bodyRegion().begin()->getArguments(); | ||||||
|   auto num_input_loops = op.getNumInputLoops(); |   auto boundItr = | ||||||
|   auto input_loop_begin = op.operand_begin() + num_optimized_loops; |       op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName()) | ||||||
|  |           .getValue() | ||||||
|  |           .begin(); | ||||||
|  |   auto operandItr = op.operand_begin() + numOptimizedLoops; | ||||||
| 
 | 
 | ||||||
|   // Set up iterators to operand bounds.
 |   std::string delimiter; | ||||||
|   auto next_operand_bound = input_loop_begin + num_input_loops; |   for (auto& var : inductionVars) { | ||||||
| 
 |     p << delimiter; | ||||||
|   // Function to print a lower or upper bound.
 |     p.printOperand(*operandItr++); | ||||||
|   auto print_bound = [&](ArrayRef<Attribute> bound_types, size_t idx) { |  | ||||||
|     IntegerAttr type = bound_types[idx].dyn_cast<IntegerAttr>(); |  | ||||||
|     if (type.getValue().getSExtValue() == 0) { |  | ||||||
|       // Bound is an integer attribute.
 |  | ||||||
|       auto bound_idx = idx / 2; |  | ||||||
|       auto is_ub = idx % 2; |  | ||||||
|       IntegerAttr bound = op.getAttrOfType<IntegerAttr>( |  | ||||||
|           KrnlIterateOp::getBoundAttrName(bound_idx, is_ub)); |  | ||||||
|       p << bound.getValue().getSExtValue(); |  | ||||||
|     } else { |  | ||||||
|       // Bound is an operand.
 |  | ||||||
|       p.printOperand(*next_operand_bound); |  | ||||||
|       next_operand_bound = std::next(next_operand_bound); |  | ||||||
|     } |  | ||||||
|   }; |  | ||||||
| 
 |  | ||||||
|   auto induction_variables = op.bodyRegion().front().getArguments(); |  | ||||||
|   ArrayRef<Attribute> bound_types = |  | ||||||
|       op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundTypesAttrName()) |  | ||||||
|           .getValue(); |  | ||||||
| 
 |  | ||||||
|   // Print input loop operands, induction variables and their ranges.
 |  | ||||||
|   for (size_t i = 0; i < num_input_loops; i++) { |  | ||||||
|     if (i != 0) |  | ||||||
|       p << ", "; |  | ||||||
| 
 |  | ||||||
|     p.printOperand(*std::next(input_loop_begin, i)); |  | ||||||
|     p << " -> "; |     p << " -> "; | ||||||
| 
 |     p.printOperand(var); | ||||||
|     // Print induction variable block argument.
 |  | ||||||
|     p.printOperand(induction_variables[i]); |  | ||||||
|     p << " = "; |     p << " = "; | ||||||
| 
 |     onnf::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "max", p); | ||||||
|     print_bound(bound_types, 2 * i);  // Print lower bound.
 |  | ||||||
|     p << " to "; |     p << " to "; | ||||||
|     print_bound(bound_types, 2 * i + 1);  // Print upper bound.
 |     onnf::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "min", p); | ||||||
|  |     delimiter = ", "; | ||||||
|   } |   } | ||||||
|   p << ")"; |  | ||||||
| 
 | 
 | ||||||
|  |   p << ")"; | ||||||
|   p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false, |   p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false, | ||||||
|       /*printBlockTerminators=*/false); |       /*printBlockTerminators=*/false); | ||||||
| } | } | ||||||
|  | @ -250,80 +200,109 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) { | ||||||
| ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) { | ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) { | ||||||
|   auto builder = parser.getBuilder(); |   auto builder = parser.getBuilder(); | ||||||
|   auto context = builder.getContext(); |   auto context = builder.getContext(); | ||||||
|   onnf::KrnlDialectOperandParser operand_parser(parser); |   onnf::KrnlDialectOperandParser operandParser(parser); | ||||||
| 
 | 
 | ||||||
|   // Parse optimized loops:
 |   // Parse optimized loops:
 | ||||||
|   SmallVector<OpAsmParser::OperandType, 4> num_optimized_loops; |   SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs; | ||||||
|   if (parser.parseOperandList( |   if (parser.parseOperandList( | ||||||
|           num_optimized_loops, OpAsmParser::Delimiter::Paren) || |           optimizedLoopRefs, OpAsmParser::Delimiter::Paren) || | ||||||
|       parser.resolveOperands(num_optimized_loops, |       parser.resolveOperands(optimizedLoopRefs, | ||||||
|           LoopType::get(result.getContext()), result.operands)) |           LoopType::get(result.getContext()), result.operands)) | ||||||
|     return failure(); |     return failure(); | ||||||
| 
 | 
 | ||||||
|   // Record how many optimized loops did we parse.
 |   // Record how many optimized loops did we parse.
 | ||||||
|   result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(), |   result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(), | ||||||
|       builder.getI64IntegerAttr(num_optimized_loops.size())); |       builder.getI64IntegerAttr(optimizedLoopRefs.size())); | ||||||
| 
 | 
 | ||||||
|   // Parse input loops and their lower and upper bounds.
 |   // Parse input loops and their lower and upper bounds.
 | ||||||
|   SmallVector<OpAsmParser::OperandType, 4> in_loop_refs, induction_var_refs; |   SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs; | ||||||
|   SmallVector<Value*, 4> in_loop_operands, operand_bounds; |   SmallVector<Attribute, 4> boundMaps; | ||||||
|   SmallVector<Attribute, 4> bound_types; |  | ||||||
|   SmallVector<IntegerAttr, 4> const_bounds; |  | ||||||
| 
 | 
 | ||||||
|   if (parser.parseKeyword("with") || parser.parseLParen()) |   if (parser.parseKeyword("with") || parser.parseLParen()) | ||||||
|     return failure(); |     return failure(); | ||||||
| 
 | 
 | ||||||
|   // A function to parse a lower or upper bound.
 |   // A function to parse a lower or upper bound.
 | ||||||
|   auto parse_bound = [&result, &builder, &operand_parser, &parser, &bound_types, |   auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps]( | ||||||
|                          &operand_bounds, &const_bounds]( |                         bool isUpper) -> ParseResult { | ||||||
|                          bool is_ub, size_t bound_pair_count) -> ParseResult { |     // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
 | ||||||
|  |     // the map has multiple results.
 | ||||||
|  |     bool failedToParsedMinMax = | ||||||
|  |         failed(parser.parseOptionalKeyword(isUpper ? "min" : "max")); | ||||||
|  | 
 | ||||||
|     // Try parse an SSA operand.
 |     // Try parse an SSA operand.
 | ||||||
|     Value* bound; |     if (succeeded(operandParser.ParseOptionalOperand( | ||||||
|     operand_parser.ParseOptionalOperand(builder.getIndexType(), bound); |             builder.getIndexType(), result.operands))) { | ||||||
|  |       AffineMap map = builder.getSymbolIdentityMap(); | ||||||
|  |       boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||||
|  |       return success(); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     if (bound != nullptr) { |     // Bound is not an SSA id, then it must be an integer.
 | ||||||
|       // Parsed an SSA id as bound.
 |     // Parse an integer constant attribute.
 | ||||||
|       operand_bounds.emplace_back(bound); |     // Get the attribute location.
 | ||||||
|       // Record bound_type as an operand type.
 |     llvm::SMLoc attrLoc = parser.getCurrentLocation(); | ||||||
|       bound_types.emplace_back(builder.getI32IntegerAttr(0)); |     Attribute boundAttr; | ||||||
|     } else { |     llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer; | ||||||
|       // Bound is not an SSA id, then it must be an integer.
 |     if (parser.parseAttribute( | ||||||
|       // Parse an integer constant attribute.
 |             boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer)) | ||||||
|       IntegerAttr boundAttr; |       return failure(); | ||||||
|       if (parser.parseAttribute(boundAttr, builder.getIndexType(), | 
 | ||||||
|               KrnlIterateOp::getBoundAttrName(bound_pair_count, is_ub), |     if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) { | ||||||
|               result.attributes)) |       unsigned currentNumOperands = result.operands.size(); | ||||||
|  |       unsigned numDims = 0; | ||||||
|  |       if (parseDimAndSymbolList(parser, result.operands, numDims)) | ||||||
|         return failure(); |         return failure(); | ||||||
|       const_bounds.emplace_back( |  | ||||||
|           builder.getIntegerAttr(builder.getIndexType(), boundAttr.getValue())); |  | ||||||
| 
 | 
 | ||||||
|       // Record that the bound_type is a constant integer attribute.
 |       auto map = affineMapAttr.getValue(); | ||||||
|       bound_types.emplace_back(builder.getI32IntegerAttr(1)); |       if (map.getNumDims() != numDims) | ||||||
|  |         return parser.emitError(parser.getNameLoc(), | ||||||
|  |             "dim operand count and integer set dim count must match"); | ||||||
|  | 
 | ||||||
|  |       unsigned numDimAndSymbolOperands = | ||||||
|  |           result.operands.size() - currentNumOperands; | ||||||
|  |       if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) | ||||||
|  |         return parser.emitError(parser.getNameLoc(), | ||||||
|  |             "symbol operand count and integer set symbol count must match"); | ||||||
|  | 
 | ||||||
|  |       // If the map has multiple results, make sure that we parsed the min/max
 | ||||||
|  |       // prefix.
 | ||||||
|  |       if (map.getNumResults() > 1 && failedToParsedMinMax) { | ||||||
|  |         if (isUpper) | ||||||
|  |           return parser.emitError(attrLoc, | ||||||
|  |               "upper loop bound affine map with multiple " | ||||||
|  |               "results requires 'min' prefix"); | ||||||
|  |         return parser.emitError(attrLoc, | ||||||
|  |             "lower loop bound affine mapwith " | ||||||
|  |             "multiple results requires 'max' prefix"); | ||||||
|  |       } | ||||||
|  |       boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||||
|  |       return success(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) { | ||||||
|  |       AffineMap map = | ||||||
|  |           builder.getConstantAffineMap(integerAttr.getValue().getSExtValue()); | ||||||
|  |       boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
| 
 | 
 | ||||||
|   bool keep_parsing;            // Do we keep parsing loops/bounds?
 |   bool keepParsing;  // Do we keep parsing loops/bounds?
 | ||||||
|   size_t bound_pair_count = 0;  // Record the number of bound pairs parsed.
 |  | ||||||
|   do { |   do { | ||||||
|     // Parse an input loop operand;
 |     // Parse an input loop operand;
 | ||||||
|     Value* in_loop_operand; |     operandParser.ParseOperand(LoopType::get(context), result.operands); | ||||||
|     operand_parser.ParseOperand(LoopType::get(context), in_loop_operand); |  | ||||||
|     in_loop_operands.emplace_back(in_loop_operand); |  | ||||||
| 
 |  | ||||||
|     parser.parseArrow(); |     parser.parseArrow(); | ||||||
| 
 | 
 | ||||||
|     // Parse induction variable.
 |     // Parse induction variable.
 | ||||||
|     OpAsmParser::OperandType induction_var; |     OpAsmParser::OperandType inductionVar; | ||||||
|     if (parser.parseRegionArgument(induction_var) || parser.parseEqual()) |     if (parser.parseRegionArgument(inductionVar) || parser.parseEqual()) | ||||||
|       return failure(); |       return failure(); | ||||||
|     induction_var_refs.emplace_back(induction_var); |     inductionVarRefs.emplace_back(inductionVar); | ||||||
| 
 | 
 | ||||||
|     // Parse bound par (min to max).
 |     // Parse bound par (min to max).
 | ||||||
|     if (parse_bound(false, bound_pair_count) || parser.parseKeyword("to") || |     if (parseBound(/*isUpper=*/false) || parser.parseKeyword("to") || | ||||||
|         parse_bound(true, bound_pair_count)) |         parseBound(/*isUpper=*/true)) | ||||||
|       return failure(); |       return failure(); | ||||||
| 
 | 
 | ||||||
|     bound_pair_count++; |  | ||||||
|     // We may fail to parse a comma if an operand bound is followed by
 |     // We may fail to parse a comma if an operand bound is followed by
 | ||||||
|     // a comma and the next input loop operand, in which case
 |     // a comma and the next input loop operand, in which case
 | ||||||
|     // the entire "{operand bound}, {input_loop_operand}" sequence will
 |     // the entire "{operand bound}, {input_loop_operand}" sequence will
 | ||||||
|  | @ -331,33 +310,19 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) { | ||||||
|     parser.parseOptionalComma(); |     parser.parseOptionalComma(); | ||||||
| 
 | 
 | ||||||
|     // If we don't see a RParen token, we keep parsing.
 |     // If we don't see a RParen token, we keep parsing.
 | ||||||
|     keep_parsing = failed(parser.parseOptionalRParen()); |     keepParsing = failed(parser.parseOptionalRParen()); | ||||||
|   } while (keep_parsing); |   } while (keepParsing); | ||||||
| 
 | 
 | ||||||
|   // At this point, there shouldn't be any operands left to parse.
 |   // At this point, there shouldn't be any operands left to parse.
 | ||||||
|   if (operand_parser.has_operand_left()) |   if (operandParser.hasOperandLeft()) | ||||||
|     return parser.emitError(parser.getCurrentLocation()); |     return parser.emitError(parser.getCurrentLocation()); | ||||||
|  |   result.addAttribute( | ||||||
|  |       KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps)); | ||||||
| 
 | 
 | ||||||
|   // Record how many input loops did we parse.
 |  | ||||||
|   result.addOperands(in_loop_operands); |  | ||||||
|   result.addAttribute(KrnlIterateOp::getNumInputLoopsAttrName(), |  | ||||||
|       builder.getI64IntegerAttr(in_loop_operands.size())); |  | ||||||
| 
 |  | ||||||
|   // Add operand bounds to the list of operands of current operation.
 |  | ||||||
|   result.addOperands(operand_bounds); |  | ||||||
| 
 |  | ||||||
|   // A list of 2N elements where the (2n) and (2n+1) th element specifies
 |  | ||||||
|   // whether the lower and upper bound of the n'th induction variable is stored
 |  | ||||||
|   // as an operand or as an attribute. N being the number of input loops
 |  | ||||||
|   // specified in this krnl.iterate operation.
 |  | ||||||
|   result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(), |  | ||||||
|       builder.getArrayAttr(bound_types)); |  | ||||||
| 
 |  | ||||||
|   // Parse the schedule body region.
 |  | ||||||
|   Region* region = result.addRegion(); |   Region* region = result.addRegion(); | ||||||
|   SmallVector<Type, 4> induction_var_types( |   SmallVector<Type, 4> inductionVarTypes( | ||||||
|       induction_var_refs.size(), builder.getIndexType()); |       inductionVarRefs.size(), builder.getIndexType()); | ||||||
|   if (parser.parseRegion(*region, induction_var_refs, induction_var_types)) |   if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes)) | ||||||
|     return failure(); |     return failure(); | ||||||
| 
 | 
 | ||||||
|   // Ensure iterate region is closed off with krnl.terminate.
 |   // Ensure iterate region is closed off with krnl.terminate.
 | ||||||
|  |  | ||||||
|  | @ -14,6 +14,7 @@ | ||||||
| #include "mlir/IR/OpDefinition.h" | #include "mlir/IR/OpDefinition.h" | ||||||
| #include "mlir/IR/StandardTypes.h" | #include "mlir/IR/StandardTypes.h" | ||||||
| 
 | 
 | ||||||
|  | #include "src/compiler/dialect/krnl/krnl_helper.hpp" | ||||||
| #include "src/compiler/dialect/krnl/krnl_types.hpp" | #include "src/compiler/dialect/krnl/krnl_types.hpp" | ||||||
| 
 | 
 | ||||||
| namespace mlir { | namespace mlir { | ||||||
|  |  | ||||||
|  | @ -8,6 +8,7 @@ | ||||||
| 
 | 
 | ||||||
| include "mlir/IR/OpBase.td" | include "mlir/IR/OpBase.td" | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def Krnl_Dialect : Dialect { | def Krnl_Dialect : Dialect { | ||||||
|   let name = "krnl"; |   let name = "krnl"; | ||||||
|   let cppNamespace = ""; |   let cppNamespace = ""; | ||||||
|  | @ -119,17 +120,14 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> { | ||||||
| 
 | 
 | ||||||
|   let builders = [ |   let builders = [ | ||||||
|     OpBuilder<"Builder *builder, OperationState &result, " |     OpBuilder<"Builder *builder, OperationState &result, " | ||||||
|               "ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops, " |               "KrnlIterateOperandPack operandPack"> | ||||||
|               "ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds, " |  | ||||||
|               "ArrayRef<int> bound_types"> |  | ||||||
|   ]; |   ]; | ||||||
| 
 | 
 | ||||||
|   let extraClassDeclaration = [{ |   let extraClassDeclaration = [{ | ||||||
| 
 |     // In krnl.iterate operation, operands are stored as such | ||||||
|     // In krnl.iterate operation, three types of SSA values are stored: |  | ||||||
|     // - Optimized krnl.loops. |     // - Optimized krnl.loops. | ||||||
|     // - Input krnl.loops. |     // - Input krnl.loops and their operand bounds. (TODO(Tian) explain better how we store them). | ||||||
|     // - SSA value based induction variable bound (parametric bound). | 
 | ||||||
|     // We record the number of optimized and input loops to separate these three |     // We record the number of optimized and input loops to separate these three | ||||||
|     // group of operands out. |     // group of operands out. | ||||||
|     static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; } |     static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; } | ||||||
|  | @ -143,32 +141,8 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> { | ||||||
|       return num_optimized_loops; |       return num_optimized_loops; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     static StringRef getNumInputLoopsAttrName() { return "num_input_loops"; } |     // Get name of the attribute for storing bound represented using affine maps. | ||||||
| 
 |     static StringRef getBoundsAttrName() { return "bounds"; } | ||||||
|     int64_t getNumInputLoops() { |  | ||||||
|       auto num_loops = |  | ||||||
|         getAttrOfType<IntegerAttr>( |  | ||||||
|                 getNumInputLoopsAttrName()) |  | ||||||
|               .getValue() |  | ||||||
|               .getSExtValue(); |  | ||||||
|       return num_loops; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // Constant bounds are stored here as a list attribute. |  | ||||||
|     static StringRef getConstantBoundsAttrName() { return "constant_bounds"; } |  | ||||||
| 
 |  | ||||||
|     // Store type of each bound as three types: |  | ||||||
|     // - 0 = constant attribute. |  | ||||||
|     // - 1 = operand type. |  | ||||||
|     // - 2 = affine maps (TODO). |  | ||||||
|     static StringRef getBoundTypesAttrName() { return "bound_types"; } |  | ||||||
| 
 |  | ||||||
|     // Get dynamic attribute name for the i-th lower and upper bound. |  | ||||||
|     static std::string getBoundAttrName(int64_t i, bool is_ub) { |  | ||||||
|       std::string bound_type = is_ub ? "_ub" : "_lb"; |  | ||||||
|       std::string bound_idx = std::to_string(i); |  | ||||||
|       return "__bound_" + bound_idx + bound_type; |  | ||||||
|     } |  | ||||||
|     }]; |     }]; | ||||||
| 
 | 
 | ||||||
|   let printer = [{ return ::print(p, *this); }]; |   let printer = [{ return ::print(p, *this); }]; | ||||||
|  |  | ||||||
|  | @ -1,52 +0,0 @@ | ||||||
| //===------------------ parser_helper.cpp - MLIR Operations ---------------===//
 |  | ||||||
| //
 |  | ||||||
| // Copyright 2019 The IBM Research Authors.
 |  | ||||||
| //
 |  | ||||||
| // =============================================================================
 |  | ||||||
| //
 |  | ||||||
| //===----------------------------------------------------------------------===//
 |  | ||||||
| 
 |  | ||||||
| #include "parser_helper.hpp" |  | ||||||
| 
 |  | ||||||
| #include "src/compiler/dialect/krnl/krnl_ops.hpp" |  | ||||||
| 
 |  | ||||||
| namespace onnf { |  | ||||||
| 
 |  | ||||||
| mlir::ParseResult KrnlDialectOperandParser::ParseOptionalOperand( |  | ||||||
|     mlir::Type operand_type, mlir::Value*& operand) { |  | ||||||
|   // If operand queue is empty, parse more operands and cache them.
 |  | ||||||
|   if (_operand_ref_queue.empty()) { |  | ||||||
|     // Parse operand types:
 |  | ||||||
|     llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> operand_refs; |  | ||||||
|     _parser.parseOperandList(operand_refs); |  | ||||||
| 
 |  | ||||||
|     // Record operands:
 |  | ||||||
|     for (auto& operand_ref : operand_refs) |  | ||||||
|       _operand_ref_queue.emplace(operand_ref); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   // If we parsed some operand reference(s), resolve the ref to an operand:
 |  | ||||||
|   if (!_operand_ref_queue.empty()) { |  | ||||||
|     auto operand_ref = _operand_ref_queue.front(); |  | ||||||
|     _operand_ref_queue.pop(); |  | ||||||
| 
 |  | ||||||
|     llvm::SmallVector<mlir::Value*, 1> operands; |  | ||||||
|     _parser.resolveOperand(operand_ref, operand_type, operands); |  | ||||||
|     operand = operands.front(); |  | ||||||
|     return mlir::success(); |  | ||||||
|   } else { |  | ||||||
|     operand = nullptr; |  | ||||||
|     return mlir::failure(); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| mlir::ParseResult KrnlDialectOperandParser::ParseOperand( |  | ||||||
|     mlir::Type operand_type, mlir::Value*& operand) { |  | ||||||
|   ParseOptionalOperand(operand_type, operand); |  | ||||||
|   if (operand == nullptr) |  | ||||||
|     return _parser.emitError( |  | ||||||
|         _parser.getCurrentLocation(), "Expecting an operand."); |  | ||||||
|   return mlir::success(); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| }  // namespace onnf
 |  | ||||||
|  | @ -1,46 +0,0 @@ | ||||||
| //===------------------ parser_helper.hpp - MLIR Operations ---------------===//
 |  | ||||||
| //
 |  | ||||||
| // Copyright 2019 The IBM Research Authors.
 |  | ||||||
| //
 |  | ||||||
| // =============================================================================
 |  | ||||||
| //
 |  | ||||||
| //===----------------------------------------------------------------------===//
 |  | ||||||
| 
 |  | ||||||
| #pragma once |  | ||||||
| 
 |  | ||||||
| #include <queue> |  | ||||||
| 
 |  | ||||||
| #include "mlir/IR/Builders.h" |  | ||||||
| #include "mlir/IR/Dialect.h" |  | ||||||
| #include "mlir/IR/OpDefinition.h" |  | ||||||
| #include "mlir/IR/OpImplementation.h" |  | ||||||
| #include "mlir/IR/StandardTypes.h" |  | ||||||
| 
 |  | ||||||
| namespace onnf { |  | ||||||
| 
 |  | ||||||
| class KrnlDialectOperandParser { |  | ||||||
|  public: |  | ||||||
|   KrnlDialectOperandParser(mlir::OpAsmParser& parser) |  | ||||||
|       : _parser(parser), _builder(parser.getBuilder()){}; |  | ||||||
| 
 |  | ||||||
|   // Parse an optional operand.
 |  | ||||||
|   mlir::ParseResult ParseOptionalOperand( |  | ||||||
|       mlir::Type operand_type, mlir::Value*& operand); |  | ||||||
| 
 |  | ||||||
|   // Parse a required operand.
 |  | ||||||
|   mlir::ParseResult ParseOperand( |  | ||||||
|       mlir::Type operand_type, mlir::Value*& operand); |  | ||||||
| 
 |  | ||||||
|   // Do we have more operands to parse?
 |  | ||||||
|   bool has_operand_left() { return !_operand_ref_queue.empty(); } |  | ||||||
| 
 |  | ||||||
|  private: |  | ||||||
|   mlir::OpAsmParser& _parser; |  | ||||||
| 
 |  | ||||||
|   mlir::Builder& _builder; |  | ||||||
| 
 |  | ||||||
|   // A queue storing the parsed SSA id references.
 |  | ||||||
|   std::queue<mlir::OpAsmParser::OperandType> _operand_ref_queue; |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| }  // namespace onnf
 |  | ||||||
|  | @ -16,6 +16,7 @@ | ||||||
| #include "mlir/Pass/Pass.h" | #include "mlir/Pass/Pass.h" | ||||||
| #include "mlir/Transforms/DialectConversion.h" | #include "mlir/Transforms/DialectConversion.h" | ||||||
| 
 | 
 | ||||||
|  | #include "src/compiler/dialect/krnl/krnl_helper.hpp" | ||||||
| #include "src/compiler/dialect/krnl/krnl_ops.hpp" | #include "src/compiler/dialect/krnl/krnl_ops.hpp" | ||||||
| #include "src/compiler/dialect/onnx/onnx_ops.hpp" | #include "src/compiler/dialect/onnx/onnx_ops.hpp" | ||||||
| 
 | 
 | ||||||
|  | @ -43,13 +44,12 @@ static MemRefType convertTensorToMemRef(TensorType type) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// Insert an allocation and deallocation for the given MemRefType.
 | /// Insert an allocation and deallocation for the given MemRefType.
 | ||||||
| static Value* insertAllocAndDealloc( | static Value* insertAllocAndDealloc(MemRefType type, Location loc, | ||||||
|     MemRefType type, Location loc, PatternRewriter& rewriter, |     PatternRewriter& rewriter, Value* oldMemRef = nullptr) { | ||||||
|     Value *oldMemRef = nullptr) { |  | ||||||
|   // Put together alloc operands for any dynamic dimensions of the memref.
 |   // Put together alloc operands for any dynamic dimensions of the memref.
 | ||||||
|   AllocOp alloc; |   AllocOp alloc; | ||||||
|   if (oldMemRef) { |   if (oldMemRef) { | ||||||
|     SmallVector<Value *, 4> allocOperands; |     SmallVector<Value*, 4> allocOperands; | ||||||
|     auto memRefShape = type.getShape(); |     auto memRefShape = type.getShape(); | ||||||
|     for (int i = 0; i < memRefShape.size(); ++i) |     for (int i = 0; i < memRefShape.size(); ++i) | ||||||
|       if (memRefShape[i] < 0) |       if (memRefShape[i] < 0) | ||||||
|  | @ -95,7 +95,7 @@ struct ONNXAddOpLowering : public ConversionPattern { | ||||||
|     // dimensions with the result at this pre-optimization phase.
 |     // dimensions with the result at this pre-optimization phase.
 | ||||||
|     // TODO: verify that dimensions match.
 |     // TODO: verify that dimensions match.
 | ||||||
|     // TODO: can the dimension of the result differ after optimizations?
 |     // TODO: can the dimension of the result differ after optimizations?
 | ||||||
|     Value *alloc; |     Value* alloc; | ||||||
|     if (hasAllConstantDimensions(memRefType)) |     if (hasAllConstantDimensions(memRefType)) | ||||||
|       alloc = insertAllocAndDealloc(memRefType, loc, rewriter); |       alloc = insertAllocAndDealloc(memRefType, loc, rewriter); | ||||||
|     else |     else | ||||||
|  | @ -122,33 +122,22 @@ struct ONNXAddOpLowering : public ConversionPattern { | ||||||
|     } |     } | ||||||
|     Block& optimizationBlock = optimizedLoopsOp.region().front(); |     Block& optimizationBlock = optimizedLoopsOp.region().front(); | ||||||
| 
 | 
 | ||||||
|  |     KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); | ||||||
|     // Iterate over the loop nest.
 |     // Iterate over the loop nest.
 | ||||||
|     // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
 |     // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
 | ||||||
|     // to KrnlIterateOp instead.
 |     // to KrnlIterateOp instead.
 | ||||||
|     SmallVector<Value*, 8> operandBounds; |  | ||||||
|     SmallVector<int64_t, 8> constBounds; |  | ||||||
|     SmallVector<int, 16> boundTypes; |  | ||||||
|     for (int i = 0; i < rank; ++i) { |     for (int i = 0; i < rank; ++i) { | ||||||
|       if (memRefShape[i] < 0) { |       if (memRefShape[i] < 0) { | ||||||
|         // This is a dynamic value, hence use operands.
 |         pack.pushConstantBound(0); | ||||||
|         // Lower bound
 |         pack.pushOperandBound( | ||||||
|         constBounds.push_back(0); |  | ||||||
|         boundTypes.push_back(0); |  | ||||||
|         // Upper bound
 |  | ||||||
|         operandBounds.push_back( |  | ||||||
|             rewriter.create<DimOp>(loc, operands[0], i).getResult()); |             rewriter.create<DimOp>(loc, operands[0], i).getResult()); | ||||||
|         boundTypes.push_back(1); |  | ||||||
|       } else { |       } else { | ||||||
|         // Lower bound
 |         pack.pushConstantBound(0); | ||||||
|         constBounds.push_back(0); |         pack.pushConstantBound(memRefShape[i]); | ||||||
|         boundTypes.push_back(0); |  | ||||||
|         // Upper bound
 |  | ||||||
|         constBounds.push_back(memRefShape[i]); |  | ||||||
|         boundTypes.push_back(0); |  | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, originalLoops, | 
 | ||||||
|         optimizedLoops, operandBounds, constBounds, boundTypes); |     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack); | ||||||
|     Block& iterationBlock = iterateOp.bodyRegion().front(); |     Block& iterationBlock = iterateOp.bodyRegion().front(); | ||||||
| 
 | 
 | ||||||
|     // Now perform the insertions into the body of the
 |     // Now perform the insertions into the body of the
 | ||||||
|  | @ -169,14 +158,12 @@ struct ONNXAddOpLowering : public ConversionPattern { | ||||||
|     SmallVector<Value*, 4> loopIVs; |     SmallVector<Value*, 4> loopIVs; | ||||||
|     for (auto arg : iterationBlock.getArguments()) |     for (auto arg : iterationBlock.getArguments()) | ||||||
|       loopIVs.push_back(arg); |       loopIVs.push_back(arg); | ||||||
|     auto loadedFirstVal = |     auto loadedFirstVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs); | ||||||
|         rewriter.create<LoadOp>(loc, operands[0], loopIVs); |     auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs); | ||||||
|     auto loadedSecondVal = |  | ||||||
|         rewriter.create<LoadOp>(loc, operands[1], loopIVs); |  | ||||||
| 
 | 
 | ||||||
|     // TODO: Choose type of the Add for now use the Float Add.
 |     // TODO: Choose type of the Add for now use the Float Add.
 | ||||||
|     auto addOpResult = rewriter.create<AddFOp>( |     auto addOpResult = | ||||||
|         loc, loadedFirstVal, loadedSecondVal); |         rewriter.create<AddFOp>(loc, loadedFirstVal, loadedSecondVal); | ||||||
| 
 | 
 | ||||||
|     // Store result in the resulting array.
 |     // Store result in the resulting array.
 | ||||||
|     rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs); |     rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs); | ||||||
|  | @ -209,8 +196,8 @@ struct TensorTypeConverter : public TypeConverter { | ||||||
|   /// inputs. Once unranked results can be handled gracefully this
 |   /// inputs. Once unranked results can be handled gracefully this
 | ||||||
|   /// override needs to be removed in favour of the original MLIR one.]
 |   /// override needs to be removed in favour of the original MLIR one.]
 | ||||||
|   bool isSignatureLegal(FunctionType funcType) { |   bool isSignatureLegal(FunctionType funcType) { | ||||||
|     return llvm::all_of(funcType.getInputs(), |     return llvm::all_of( | ||||||
|         [this](Type type) { return isLegal(type); }); |         funcType.getInputs(), [this](Type type) { return isLegal(type); }); | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | @ -272,8 +259,7 @@ void FrontendToKrnlLoweringPass::runOnModule() { | ||||||
|   // With the target and rewrite patterns defined, we can now attempt the
 |   // With the target and rewrite patterns defined, we can now attempt the
 | ||||||
|   // conversion. The conversion will signal failure if any of our `illegal`
 |   // conversion. The conversion will signal failure if any of our `illegal`
 | ||||||
|   // operations were not converted successfully.
 |   // operations were not converted successfully.
 | ||||||
|   if (failed(applyPartialConversion( |   if (failed(applyPartialConversion(module, target, patterns))) | ||||||
|           module, target, patterns))) |  | ||||||
|     signalPassFailure(); |     signalPassFailure(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -17,9 +17,10 @@ class Pass; | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<Pass> createShapeInferencePass(); | std::unique_ptr<Pass> createShapeInferencePass(); | ||||||
| 
 | 
 | ||||||
| /// Pass for lowering frontend dialects to Krnl IR dialect.
 | /// Add pass for lowering to Krnl IR.
 | ||||||
| std::unique_ptr<mlir::Pass> createLowerToKrnlPass(); | std::unique_ptr<Pass> createLowerToKrnlPass(); | ||||||
| 
 | 
 | ||||||
| // TODO: Add pass for lowering to LLVM IR.
 | /// Pass for lowering frontend dialects to Krnl IR dialect.
 | ||||||
|  | std::unique_ptr<Pass> createLowerKrnlPass(); | ||||||
| 
 | 
 | ||||||
| }  // end namespace mlir
 | }  // end namespace mlir
 | ||||||
|  |  | ||||||
|  | @ -75,7 +75,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { | ||||||
|     if (auto terminator_op = f.getBody().back().getTerminator()) { |     if (auto terminator_op = f.getBody().back().getTerminator()) { | ||||||
|       auto results = terminator_op->getOperandTypes(); |       auto results = terminator_op->getOperandTypes(); | ||||||
|       f.setType(FunctionType::get(f.getType().getInputs(), |       f.setType(FunctionType::get(f.getType().getInputs(), | ||||||
|                 std::vector<Type>(results.begin(), results.end()), f.getContext())); |           std::vector<Type>(results.begin(), results.end()), f.getContext())); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -6,8 +6,7 @@ target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT}) | ||||||
| 
 | 
 | ||||||
| target_link_libraries(onnf-opt compiler ${MLIRLibs}) | target_link_libraries(onnf-opt compiler ${MLIRLibs}) | ||||||
| whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs}) | whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs}) | ||||||
| whole_archive_link_onnf(onnf-opt onnf_lower_frontend) | whole_archive_link_onnf(onnf-opt onnf_transform onnf_lower_frontend onnf_shape_inference) | ||||||
| whole_archive_link_onnf(onnf-opt onnf_shape_inference) |  | ||||||
| 
 | 
 | ||||||
| # TODO: need to investigate how to whole-archive link compiler pass to onnf-opt. | # TODO: need to investigate how to whole-archive link compiler pass to onnf-opt. | ||||||
| target_link_libraries(onnf-opt compiler) | target_link_libraries(onnf-opt compiler) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,7 @@ | ||||||
|  | add_library(onnf_transform lower_krnl.cpp) | ||||||
|  | 
 | ||||||
|  | target_include_directories(onnf_transform | ||||||
|  |                            PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} | ||||||
|  |                                    ${ONNF_SRC_ROOT}) | ||||||
|  | target_link_libraries(onnf_transform ${MLIRLibs}) | ||||||
|  | add_dependencies(onnf_transform gen_krnl_ops) | ||||||
|  | @ -0,0 +1,161 @@ | ||||||
|  | #include "mlir/Dialect/AffineOps/AffineOps.h" | ||||||
|  | #include "mlir/Dialect/StandardOps/Ops.h" | ||||||
|  | #include "mlir/Pass/Pass.h" | ||||||
|  | #include "mlir/Transforms/DialectConversion.h" | ||||||
|  | 
 | ||||||
|  | #include "src/compiler/dialect/krnl/krnl_ops.hpp" | ||||||
|  | #include "src/compiler/pass/passes.hpp" | ||||||
|  | 
 | ||||||
|  | using namespace mlir; | ||||||
|  | 
 | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | // Krnl to Affine Rewrite Patterns: KrnlIterate operation.
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> { | ||||||
|  |   using OpRewritePattern<KrnlIterateOp>::OpRewritePattern; | ||||||
|  | 
 | ||||||
|  |   PatternMatchResult matchAndRewrite( | ||||||
|  |       KrnlIterateOp iterateOp, PatternRewriter& rewriter) const override { | ||||||
|  |     auto boundMapAttrs = | ||||||
|  |         iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName()) | ||||||
|  |             .getValue(); | ||||||
|  |     auto operandItr = | ||||||
|  |         iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops(); | ||||||
|  |     SmallVector<AffineForOp, 4> nestedForOps; | ||||||
|  |     for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) { | ||||||
|  |       // Consume input loop operand, currently do not do anything with it.
 | ||||||
|  |       operandItr++; | ||||||
|  | 
 | ||||||
|  |       // Organize operands into lower/upper bounds in affine.for ready formats.
 | ||||||
|  |       SmallVector<Value*, 4> lbOperands, ubOperands; | ||||||
|  |       AffineMap lbMap, ubMap; | ||||||
|  |       for (int boundType = 0; boundType < 2; boundType++) { | ||||||
|  |         auto& operands = boundType == 0 ? lbOperands : ubOperands; | ||||||
|  |         auto& map = boundType == 0 ? lbMap : ubMap; | ||||||
|  |         map = boundMapAttrs[boundIdx + boundType] | ||||||
|  |                   .cast<AffineMapAttr>() | ||||||
|  |                   .getValue(); | ||||||
|  |         operands.insert( | ||||||
|  |             operands.end(), operandItr, operandItr + map.getNumInputs()); | ||||||
|  |         std::advance(operandItr, map.getNumInputs()); | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       nestedForOps.emplace_back(rewriter.create<AffineForOp>( | ||||||
|  |           iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap)); | ||||||
|  |       rewriter.setInsertionPoint(nestedForOps.back().getBody(), | ||||||
|  |           nestedForOps.back().getBody()->begin()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Replace induction variable references from those introduced by a
 | ||||||
|  |     // single krnl.iterate to those introduced by multiple affine.for
 | ||||||
|  |     // operations.
 | ||||||
|  |     for (size_t i = 0; i < nestedForOps.size() - 1; i++) { | ||||||
|  |       auto iterateIV = iterateOp.bodyRegion().front().getArgument(0); | ||||||
|  |       auto forIV = nestedForOps[i].getBody()->getArgument(0); | ||||||
|  |       iterateIV->replaceAllUsesWith(forIV); | ||||||
|  |       iterateOp.bodyRegion().front().eraseArgument(0); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Pop krnl.iterate body region block arguments, leave the last one
 | ||||||
|  |     // for convenience (it'll be taken care of by region inlining).
 | ||||||
|  |     while (iterateOp.bodyRegion().front().getNumArguments() > 1) | ||||||
|  |       iterateOp.bodyRegion().front().eraseArgument(0); | ||||||
|  | 
 | ||||||
|  |     // Transfer krnl.iterate region to innermost for op.
 | ||||||
|  |     auto innermostForOp = nestedForOps.back(); | ||||||
|  |     innermostForOp.region().getBlocks().clear(); | ||||||
|  |     rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(), | ||||||
|  |         innermostForOp.region().end()); | ||||||
|  | 
 | ||||||
|  |     rewriter.eraseOp(iterateOp); | ||||||
|  |     return matchSuccess(); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | // Krnl to Affine Rewrite Patterns: KrnlTerminator operation.
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> { | ||||||
|  |  public: | ||||||
|  |   using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern; | ||||||
|  | 
 | ||||||
|  |   PatternMatchResult matchAndRewrite( | ||||||
|  |       KrnlTerminatorOp op, PatternRewriter& rewriter) const override { | ||||||
|  |     rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op); | ||||||
|  |     return matchSuccess(); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | // Krnl to Affine Rewrite Patterns: KrnlDefineLoops operation.
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> { | ||||||
|  |  public: | ||||||
|  |   using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern; | ||||||
|  | 
 | ||||||
|  |   PatternMatchResult matchAndRewrite( | ||||||
|  |       KrnlDefineLoopsOp op, PatternRewriter& rewriter) const override { | ||||||
|  |     rewriter.eraseOp(op); | ||||||
|  |     return matchSuccess(); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | // Krnl to Affine Rewrite Patterns: KrnlOptimizeLoops operation.
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> { | ||||||
|  |  public: | ||||||
|  |   using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern; | ||||||
|  | 
 | ||||||
|  |   PatternMatchResult matchAndRewrite( | ||||||
|  |       KrnlOptimizeLoopsOp op, PatternRewriter& rewriter) const override { | ||||||
|  |     rewriter.eraseOp(op); | ||||||
|  |     return matchSuccess(); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | // KrnlToAffineLoweringPass
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | /// This is a partial lowering to affine loops of the krnl dialect operations.
 | ||||||
|  | /// At this stage the dialect will contain standard operations as well like
 | ||||||
|  | /// add and multiply, this pass will leave these operations intact.
 | ||||||
|  | namespace { | ||||||
|  | struct KrnlToAffineLoweringPass | ||||||
|  |     : public FunctionPass<KrnlToAffineLoweringPass> { | ||||||
|  |   void runOnFunction() final; | ||||||
|  | }; | ||||||
|  | }  // end anonymous namespace.
 | ||||||
|  | 
 | ||||||
|  | void KrnlToAffineLoweringPass::runOnFunction() { | ||||||
|  |   auto function = getFunction(); | ||||||
|  | 
 | ||||||
|  |   ConversionTarget target(getContext()); | ||||||
|  | 
 | ||||||
|  |   target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>(); | ||||||
|  |   // We expect IR to be free of Krnl Dialect Ops.
 | ||||||
|  |   target.addIllegalDialect<KrnlOpsDialect>(); | ||||||
|  | 
 | ||||||
|  |   OwningRewritePatternList patterns; | ||||||
|  |   patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering, | ||||||
|  |       KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext()); | ||||||
|  | 
 | ||||||
|  |   if (failed(applyPartialConversion(getFunction(), target, patterns))) | ||||||
|  |     signalPassFailure(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | 
 | ||||||
|  | std::unique_ptr<Pass> mlir::createLowerKrnlPass() { | ||||||
|  |   return std::make_unique<KrnlToAffineLoweringPass>(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | static PassRegistration<KrnlToAffineLoweringPass> pass( | ||||||
|  |     "lower-krnl", "Lower Krnl dialect."); | ||||||
							
								
								
									
										16
									
								
								src/main.cpp
								
								
								
								
							
							
						
						
									
										16
									
								
								src/main.cpp
								
								
								
								
							|  | @ -28,6 +28,7 @@ | ||||||
| 
 | 
 | ||||||
| #include <boost/program_options.hpp> | #include <boost/program_options.hpp> | ||||||
| 
 | 
 | ||||||
|  | #include "llvm/Bitcode/BitcodeWriter.h" | ||||||
| #include "llvm/Support/FileUtilities.h" | #include "llvm/Support/FileUtilities.h" | ||||||
| #include "llvm/Support/Regex.h" | #include "llvm/Support/Regex.h" | ||||||
| #include "llvm/Support/SourceMgr.h" | #include "llvm/Support/SourceMgr.h" | ||||||
|  | @ -38,6 +39,8 @@ | ||||||
| #include "src/compiler/pass/passes.hpp" | #include "src/compiler/pass/passes.hpp" | ||||||
| 
 | 
 | ||||||
| #include "mlir/Analysis/Verifier.h" | #include "mlir/Analysis/Verifier.h" | ||||||
|  | #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" | ||||||
|  | #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" | ||||||
| #include "mlir/ExecutionEngine/ExecutionEngine.h" | #include "mlir/ExecutionEngine/ExecutionEngine.h" | ||||||
| #include "mlir/ExecutionEngine/OptUtils.h" | #include "mlir/ExecutionEngine/OptUtils.h" | ||||||
| #include "mlir/IR/MLIRContext.h" | #include "mlir/IR/MLIRContext.h" | ||||||
|  | @ -125,7 +128,20 @@ int main(int ac, char* av[]) { | ||||||
|   pm.addPass(mlir::createShapeInferencePass()); |   pm.addPass(mlir::createShapeInferencePass()); | ||||||
|   pm.addPass(mlir::createCanonicalizerPass()); |   pm.addPass(mlir::createCanonicalizerPass()); | ||||||
|   pm.addPass(mlir::createLowerToKrnlPass()); |   pm.addPass(mlir::createLowerToKrnlPass()); | ||||||
|  |   pm.addPass(mlir::createLowerKrnlPass()); | ||||||
|  |   pm.addPass(mlir::createLowerAffinePass()); | ||||||
|  |   pm.addPass(mlir::createLowerToCFGPass()); | ||||||
|  |   pm.addPass(mlir::createLowerToLLVMPass()); | ||||||
|  |   pm.addPass(mlir::createCanonicalizerPass()); | ||||||
|   pm.run(*module); |   pm.run(*module); | ||||||
| 
 | 
 | ||||||
|  |   // Write LLVM bitcode to disk.
 | ||||||
|  |   std::error_code EC; | ||||||
|  |   llvm::raw_fd_ostream moduleBitcodeStream( | ||||||
|  |       "model.bc", EC, llvm::sys::fs::F_None); | ||||||
|  |   llvm::WriteBitcodeToFile( | ||||||
|  |       *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); | ||||||
|  |   moduleBitcodeStream.flush(); | ||||||
|  | 
 | ||||||
|   return 0; |   return 0; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,75 @@ | ||||||
|  | // RUN: onnf-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s | ||||||
|  | // RUN: onnf-opt %s | FileCheck %s | ||||||
|  | 
 | ||||||
|  | // GENERIC-DAG: #{{.*}} = () -> (0) | ||||||
|  | // GENERIC-DAG: #{{.*}} = () -> (10) | ||||||
|  | // GENERIC-DAG: #{{.*}} = () -> (1) | ||||||
|  | // GENERIC-DAG: #{{.*}} = () -> (11) | ||||||
|  | // GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 - d1) | ||||||
|  | // GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 + d1) | ||||||
|  | 
 | ||||||
|  | func @simple_iterate(%N : index) { | ||||||
|  |   %ii, %ij, %ik = krnl.define_loops 3 | ||||||
|  |   %oi, %oj, %ok = krnl.optimize_loops  { | ||||||
|  |     krnl.return_loops %ii, %ij, %ik | ||||||
|  |   } : () -> (!krnl.loop, !krnl.loop, !krnl.loop) | ||||||
|  | 
 | ||||||
|  |   // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||||
|  |   // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): | ||||||
|  |   // GENERIC-NEXT: "krnl.terminate"() : () -> () | ||||||
|  |   // GENERIC-NEXT: bounds = [#{{.*}}, #{{.*}}, #{{.*}}, #{{.*}}] | ||||||
|  | 
 | ||||||
|  |   // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 1 to 11) { | ||||||
|  |   krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 1 to 11) { | ||||||
|  | 
 | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||||
|  |   // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): | ||||||
|  |   // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) { | ||||||
|  |   krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 0 to 10) { | ||||||
|  |     // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}) ( { | ||||||
|  |     // GENERIC-NEXT: ^bb0(%{{.*}}: index): | ||||||
|  |     // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10) { | ||||||
|  |     krnl.iterate(%ok) with (%ik -> %k = 0 to 10) { | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||||
|  |   // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): | ||||||
|  |   // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to %{{.*}}, %{{.*}} -> %{{.*}} = 0 to 10) { | ||||||
|  |   krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to %N, %ij -> %j = 0 to 10) { | ||||||
|  | 
 | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func @affine_map_bound(%N : index) { | ||||||
|  |   %ii, %ij, %ik = krnl.define_loops 3 | ||||||
|  |   %oi, %oj, %ok = krnl.optimize_loops  { | ||||||
|  |     krnl.return_loops %ii, %ij, %ik | ||||||
|  |   } : () -> (!krnl.loop, !krnl.loop, !krnl.loop) | ||||||
|  | 
 | ||||||
|  |   // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||||
|  |   // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): | ||||||
|  |   // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) { | ||||||
|  |   krnl.iterate(%oi, %oj) with (%ii -> %i = ()->(0)() to ()->(10)(), %ij -> %j = 0 to 10) { | ||||||
|  |     // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||||
|  |     // GENERIC-NEXT: ^bb0(%{{.*}}: index): | ||||||
|  |     // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = #{{.*}}(%{{.*}}, %{{.*}}) to #{{.*}}(%{{.*}}, %{{.*}})) { | ||||||
|  |     krnl.iterate(%ok) with (%ik -> %k = (d0, d1)->(d0 - d1)(%i, %j) to (d0, d1)->(d0 + d1)(%i, %j)) { | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { | ||||||
|  |     // GENERIC-NEXT: ^bb0(%{{.*}}: index): | ||||||
|  |     // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = max #map{{.*}}(%{{.*}}, %{{.*}}) to min #map{{.*}}(%{{.*}}, %{{.*}})[%{{.*}}]) { | ||||||
|  |     krnl.iterate(%ok) with (%ik -> %k = max (d0, d1)->(d0 - d1, 0)(%i, %j) to min (d0, d1)[s0]->(d0 + d1, s0)(%i, %j)[%N]) { | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   return | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue