[MLIR] Refactor Krnl Dialect and Krnl Dialect Lowering (#375)
* Store bounds as affine map attributes & check in test cases with generic printer * Upgrading MLIR MLIR is outdated on Buildbot, rebuilding a newer version. * work with new version of mlir * check-in parser tests * custom printer * nit * bug fix * enable custom asm printer test * enable custom asm printer test * more consistent variable naming * test max/min * variable naming scheme change to MLIR style * can lower krnl to llvm * kernel -> llvm * comments * bug fix * try fixing ci * fix ci * deactivate model test * fix lit test * nit * fix z buildbot
This commit is contained in:
parent
652ce4b7d4
commit
b2a1103915
|
@ -18,6 +18,8 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
|
|||
#TODO(eventually enable the following)
|
||||
#set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
include(MLIR.cmake)
|
||||
|
||||
add_subdirectory(third_party/onnx)
|
||||
add_subdirectory(third_party/benchmark)
|
||||
add_subdirectory(third_party/pybind11)
|
||||
|
@ -41,7 +43,6 @@ if(Boost_FOUND)
|
|||
include_directories(${Boost_INCLUDE_DIRS})
|
||||
endif()
|
||||
|
||||
include(MLIR.cmake)
|
||||
add_subdirectory(src/builder)
|
||||
add_subdirectory(src/compiler)
|
||||
add_subdirectory(src)
|
||||
|
|
92
MLIR.cmake
92
MLIR.cmake
|
@ -5,7 +5,7 @@ if(DEFINED ENV{LLVM_SRC})
|
|||
message(STATUS "LLVM_SRC " ${LLVM_SRC})
|
||||
else()
|
||||
message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: "
|
||||
${LLVM_SRC})
|
||||
${LLVM_SRC})
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "env variable LLVM_SRC not set")
|
||||
|
@ -18,7 +18,7 @@ if(DEFINED ENV{LLVM_BUILD})
|
|||
message(STATUS "LLVM_BUILD " ${LLVM_BUILD})
|
||||
else()
|
||||
message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: "
|
||||
${LLVM_BUILD})
|
||||
${LLVM_BUILD})
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "env variable LLVM_BUILD not set")
|
||||
|
@ -39,9 +39,9 @@ set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
|
|||
set(ONNF_LIT_TEST_BUILD_DIR ${CMAKE_BINARY_DIR}/test/mlir)
|
||||
|
||||
set(
|
||||
MLIR_INCLUDE_PATHS
|
||||
${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH}
|
||||
)
|
||||
MLIR_INCLUDE_PATHS
|
||||
${LLVM_SRC_INCLUDE_PATH};${LLVM_BIN_INCLUDE_PATH};${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH}
|
||||
)
|
||||
include_directories(${MLIR_INCLUDE_PATHS})
|
||||
|
||||
# Threading libraries required due to parallel pass execution.
|
||||
|
@ -49,9 +49,9 @@ find_package(Threads REQUIRED)
|
|||
|
||||
function(find_mlir_lib lib)
|
||||
find_library(${lib}
|
||||
NAMES ${lib}
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
NAMES ${lib}
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
endfunction(find_mlir_lib)
|
||||
|
||||
find_mlir_lib(MLIRAffineOps)
|
||||
|
@ -70,6 +70,10 @@ find_mlir_lib(MLIRTransforms)
|
|||
find_mlir_lib(MLIRTransformUtils)
|
||||
find_mlir_lib(MLIRSupport)
|
||||
find_mlir_lib(MLIROptMain)
|
||||
find_mlir_lib(MLIRTargetLLVMIRModuleTranslation)
|
||||
find_mlir_lib(MLIRTargetLLVMIR)
|
||||
find_mlir_lib(MLIRTransformUtils)
|
||||
find_mlir_lib(MLIRTranslation)
|
||||
find_mlir_lib(MLIRVectorOps)
|
||||
|
||||
find_mlir_lib(LLVMCore)
|
||||
|
@ -80,46 +84,52 @@ find_mlir_lib(LLVMRemarks)
|
|||
find_mlir_lib(LLVMIRReader)
|
||||
find_mlir_lib(LLVMTransformUtils)
|
||||
find_mlir_lib(LLVMBitstreamReader)
|
||||
find_mlir_lib(LLVMAnalysis)
|
||||
find_mlir_lib(LLVMBitWriter)
|
||||
find_mlir_lib(LLVMBitReader)
|
||||
find_mlir_lib(LLVMMC)
|
||||
find_mlir_lib(LLVMMCParser)
|
||||
find_mlir_lib(LLVMObject)
|
||||
find_mlir_lib(LLVMProfileData)
|
||||
find_mlir_lib(LLVMDemangle)
|
||||
|
||||
|
||||
set(MLIRLibsOnce
|
||||
LLVMAnalysis
|
||||
LLVMAsmParser
|
||||
LLVMBinaryFormat
|
||||
LLVMBitReader
|
||||
LLVMBitstreamReader
|
||||
LLVMBitWriter
|
||||
LLVMCore
|
||||
LLVMIRReader
|
||||
LLVMMC
|
||||
LLVMMCParser
|
||||
LLVMObject
|
||||
LLVMRemarks
|
||||
LLVMSupport
|
||||
LLVMTransformUtils
|
||||
LLVMProfileData
|
||||
LLVMDemangle
|
||||
MLIRAffineOps
|
||||
MLIRAffineToStandard
|
||||
MLIRAnalysis
|
||||
MLIRExecutionEngine
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRLoopOps
|
||||
MLIRLoopToStandard
|
||||
MLIROptMain
|
||||
MLIRParser
|
||||
MLIRPass
|
||||
MLIRStandardOps
|
||||
MLIRStandardToLLVM
|
||||
MLIRSupport
|
||||
MLIRTargetLLVMIR
|
||||
MLIRTransforms
|
||||
MLIRAffineOps
|
||||
MLIRAffineToStandard
|
||||
MLIRAnalysis
|
||||
MLIRExecutionEngine
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRLoopToStandard
|
||||
MLIRParser
|
||||
MLIRPass
|
||||
MLIRStandardOps
|
||||
MLIRStandardToLLVM
|
||||
MLIRTargetLLVMIR
|
||||
MLIRTargetLLVMIRModuleTranslation
|
||||
MLIRTransforms
|
||||
MLIRTransformUtils
|
||||
MLIRLoopOps
|
||||
MLIRSupport
|
||||
MLIROptMain
|
||||
LLVMCore
|
||||
LLVMSupport
|
||||
LLVMAsmParser
|
||||
LLVMIRReader
|
||||
LLVMTransformUtils
|
||||
LLVMBinaryFormat
|
||||
LLVMRemarks
|
||||
LLVMBitstreamReader)
|
||||
MLIRTranslation)
|
||||
|
||||
set(MLIRLibs
|
||||
${MLIRLibsOnce}
|
||||
|
@ -142,7 +152,7 @@ function(whole_archive_link target lib_dir)
|
|||
set(link_flags "${link_flags} -L${lib_dir} ")
|
||||
foreach(LIB ${ARGN})
|
||||
string(CONCAT link_flags ${link_flags}
|
||||
"-Wl,-force_load ${lib_dir}/lib${LIB}.a ")
|
||||
"-Wl,-force_load ${lib_dir}/lib${LIB}.a ")
|
||||
endforeach(LIB)
|
||||
elseif(MSVC)
|
||||
foreach(LIB ${ARGN})
|
||||
|
@ -170,20 +180,20 @@ function(whole_archive_link_onnf target)
|
|||
endfunction(whole_archive_link_onnf)
|
||||
|
||||
set(LLVM_CMAKE_DIR
|
||||
"${LLVM_BUILD}/lib/cmake/llvm"
|
||||
CACHE PATH "Path to LLVM cmake modules")
|
||||
"${LLVM_BUILD}/lib/cmake/llvm"
|
||||
CACHE PATH "Path to LLVM cmake modules")
|
||||
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
|
||||
include(AddLLVM)
|
||||
include(TableGen)
|
||||
|
||||
function(onnf_tablegen ofn)
|
||||
tablegen(MLIR
|
||||
${ARGV}
|
||||
"-I${MLIR_SRC_INCLUDE_PATH}"
|
||||
"-I${MLIR_BIN_INCLUDE_PATH}")
|
||||
${ARGV}
|
||||
"-I${MLIR_SRC_INCLUDE_PATH}"
|
||||
"-I${MLIR_BIN_INCLUDE_PATH}")
|
||||
set(TABLEGEN_OUTPUT
|
||||
${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
|
||||
PARENT_SCOPE)
|
||||
${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
|
||||
PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Import the pre-built mlir TableGen as an imported exetuable. It is required by
|
||||
|
@ -191,5 +201,5 @@ endfunction()
|
|||
# table gen utility itself can be detected and cause re-compilation of .td file.
|
||||
add_executable(mlir-tblgen IMPORTED)
|
||||
set_property(TARGET mlir-tblgen
|
||||
PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen)
|
||||
PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen)
|
||||
set(MLIR_TABLEGEN_EXE mlir-tblgen)
|
||||
|
|
|
@ -2,8 +2,9 @@ add_executable(onnf main.cpp)
|
|||
|
||||
target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES})
|
||||
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
||||
whole_archive_link_onnf(onnf onnf_transform)
|
||||
|
||||
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
||||
|
||||
install(TARGETS onnf DESTINATION bin)
|
||||
install(TARGETS onnf DESTINATION bin)
|
||||
|
|
|
@ -6,8 +6,8 @@ add_library(
|
|||
dialect/krnl/krnl_types.hpp
|
||||
dialect/onnx/onnx_ops.cpp
|
||||
dialect/onnx/onnx_ops.hpp
|
||||
dialect/krnl/parser_helper.cpp
|
||||
dialect/krnl/parser_helper.hpp
|
||||
dialect/krnl/krnl_helper.cpp
|
||||
dialect/krnl/krnl_helper.hpp
|
||||
pass/shape_inference_pass.cpp
|
||||
pass/shape_inference_interface.hpp
|
||||
dialect/onnx/onnxop.inc
|
||||
|
@ -82,4 +82,5 @@ target_include_directories(onnf_lower_frontend
|
|||
target_link_libraries(onnf_lower_frontend ${MLIRLibs})
|
||||
add_dependencies(onnf_lower_frontend gen_krnl_ops)
|
||||
|
||||
add_subdirectory(transform)
|
||||
add_subdirectory(tool)
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||
|
||||
#include "krnl_helper.hpp"
|
||||
|
||||
namespace onnf {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||
const Type& operandType, Value*& operand) {
|
||||
// If operand queue is empty, parse more operands and cache them.
|
||||
if (_operandRefQueue.empty()) {
|
||||
// Parse operand types:
|
||||
llvm::SmallVector<OpAsmParser::OperandType, 2> operand_refs;
|
||||
_parser.parseOperandList(operand_refs);
|
||||
|
||||
// Record operands:
|
||||
for (auto& operand_ref : operand_refs)
|
||||
_operandRefQueue.emplace(operand_ref);
|
||||
}
|
||||
|
||||
// If we parsed some operand reference(s), resolve the ref to an operand:
|
||||
if (!_operandRefQueue.empty()) {
|
||||
auto operand_ref = _operandRefQueue.front();
|
||||
_operandRefQueue.pop();
|
||||
|
||||
llvm::SmallVector<Value*, 1> operands;
|
||||
_parser.resolveOperand(operand_ref, operandType, operands);
|
||||
operand = operands.front();
|
||||
return success();
|
||||
} else {
|
||||
operand = nullptr;
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
|
||||
Value* operand = nullptr;
|
||||
if (ParseOptionalOperand(operandType, operand))
|
||||
return failure();
|
||||
|
||||
operandList.emplace_back(operand);
|
||||
return success();
|
||||
}
|
||||
|
||||
ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||
const Type& operandType, Value*& operand) {
|
||||
if (ParseOptionalOperand(operandType, operand))
|
||||
return _parser.emitError(
|
||||
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||
return success();
|
||||
}
|
||||
|
||||
ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
|
||||
if (ParseOptionalOperand(operandType, operandList))
|
||||
return _parser.emitError(
|
||||
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
|
||||
unsigned numSymbols, OpAsmPrinter& p) {
|
||||
p << '(';
|
||||
p.printOperands(begin, begin + numDims);
|
||||
p << ')';
|
||||
|
||||
if (numSymbols) {
|
||||
p << '[';
|
||||
p.printOperands(begin + numDims, begin + numDims + numSymbols);
|
||||
p << ']';
|
||||
}
|
||||
|
||||
begin = std::next(begin, numDims + numSymbols);
|
||||
}
|
||||
|
||||
void printBound(AffineMapAttr boundMap,
|
||||
Operation::operand_iterator& boundOperandsBeg, const char* prefix,
|
||||
OpAsmPrinter& p) {
|
||||
AffineMap map = boundMap.getValue();
|
||||
|
||||
// Check if this bound should be printed using custom assembly form.
|
||||
// The decision to restrict printing custom assembly form to trivial cases
|
||||
// comes from the will to roundtrip MLIR binary -> text -> binary in a
|
||||
// lossless way.
|
||||
// Therefore, custom assembly form parsing and printing is only supported for
|
||||
// zero-operand constant maps and single symbol operand identity maps.
|
||||
if (map.getNumResults() == 1) {
|
||||
AffineExpr expr = map.getResult(0);
|
||||
|
||||
// Print constant bound.
|
||||
if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
|
||||
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
|
||||
p << constExpr.getValue();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Print bound that consists of a single SSA symbol if the map is over a
|
||||
// single symbol.
|
||||
if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
|
||||
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
|
||||
p.printOperand(*(boundOperandsBeg++));
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Map has multiple results. Print 'min' or 'max' prefix.
|
||||
p << prefix << ' ';
|
||||
}
|
||||
|
||||
// Print the map and its operands.
|
||||
p << boundMap;
|
||||
printDimAndSymbolList(
|
||||
boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p);
|
||||
}
|
||||
} // namespace onnf
|
||||
|
||||
namespace mlir {
|
||||
void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
|
||||
if (boundMaps.size() % 2 == 0)
|
||||
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
|
||||
AffineMap map = builder.getConstantAffineMap(bound);
|
||||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||
}
|
||||
|
||||
void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) {
|
||||
if (boundMaps.size() % 2 == 0)
|
||||
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
|
||||
AffineMap map = builder.getSymbolIdentityMap();
|
||||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||
_operands.emplace_back(operand);
|
||||
}
|
||||
} // namespace mlir
|
|
@ -0,0 +1,102 @@
|
|||
#pragma once
|
||||
|
||||
#include <queue>
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace onnf {
|
||||
|
||||
class KrnlDialectOperandParser {
|
||||
public:
|
||||
explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser)
|
||||
: _parser(parser), _builder(parser.getBuilder()){};
|
||||
|
||||
// Parse an optional operand.
|
||||
mlir::ParseResult ParseOptionalOperand(
|
||||
const mlir::Type& operandType, mlir::Value*& operand);
|
||||
|
||||
// Parse an optional operand and push it to an operand list.
|
||||
mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType,
|
||||
llvm::SmallVectorImpl<mlir::Value*>& operandList);
|
||||
|
||||
// Parse a required operand.
|
||||
mlir::ParseResult ParseOperand(
|
||||
const mlir::Type& operandType, mlir::Value*& operand);
|
||||
|
||||
// Parse a required operand and push it to an operand list.
|
||||
mlir::ParseResult ParseOperand(const mlir::Type& operandType,
|
||||
llvm::SmallVectorImpl<mlir::Value*>& operandList);
|
||||
|
||||
// Do we have more operands to parse?
|
||||
bool hasOperandLeft() { return !_operandRefQueue.empty(); }
|
||||
|
||||
private:
|
||||
mlir::OpAsmParser& _parser;
|
||||
|
||||
mlir::Builder& _builder;
|
||||
|
||||
// A queue storing the parsed SSA id references.
|
||||
std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue;
|
||||
};
|
||||
|
||||
// Adapted from:
|
||||
// https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
|
||||
// Main difference is that it advances the iterator `begin` as it consumes
|
||||
// dimension and symbol operands.
|
||||
void printDimAndSymbolList(mlir::Operation::operand_iterator& begin,
|
||||
unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p);
|
||||
|
||||
// Adapted from:
|
||||
// https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
|
||||
// Main difference is that it advances the iterator `boundOperandsBeg` as it
|
||||
// prints bound.
|
||||
void printBound(mlir::AffineMapAttr boundMap,
|
||||
mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix,
|
||||
mlir::OpAsmPrinter& p);
|
||||
} // namespace onnf
|
||||
|
||||
namespace mlir {
|
||||
|
||||
struct KrnlIterateOperandPack {
|
||||
KrnlIterateOperandPack(mlir::Builder& builder,
|
||||
llvm::ArrayRef<mlir::Value*> inputLoops,
|
||||
llvm::ArrayRef<mlir::Value*> optimizedLoops)
|
||||
: builder(builder),
|
||||
inputLoops(inputLoops),
|
||||
optimizedLoops(optimizedLoops) {
|
||||
_operands.insert(
|
||||
_operands.end(), optimizedLoops.begin(), optimizedLoops.end());
|
||||
}
|
||||
|
||||
void pushConstantBound(int64_t bound);
|
||||
|
||||
void pushOperandBound(mlir::Value* operand);
|
||||
|
||||
llvm::SmallVector<mlir::Value*, 8> getOperands() const { return _operands; }
|
||||
|
||||
mlir::ArrayAttr getAttributes() const {
|
||||
return builder.getArrayAttr(boundMaps);
|
||||
}
|
||||
|
||||
size_t getNumOptimizedLoops() const { return optimizedLoops.size(); }
|
||||
|
||||
size_t getNumInputLoops() const { return inputLoops.size(); }
|
||||
|
||||
private:
|
||||
int _boundIdx = 0;
|
||||
|
||||
llvm::SmallVector<mlir::Value*, 8> _operands;
|
||||
|
||||
llvm::SmallVector<mlir::Attribute, 8> boundMaps;
|
||||
|
||||
llvm::ArrayRef<mlir::Value*> inputLoops, optimizedLoops;
|
||||
|
||||
mlir::Builder& builder;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
|
@ -9,10 +9,10 @@
|
|||
#include <iostream>
|
||||
#include <queue>
|
||||
|
||||
#include "src/compiler/dialect/krnl/parser_helper.hpp"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
|
@ -24,6 +24,8 @@
|
|||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
|
||||
|
||||
#include "krnl_ops.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -52,26 +54,25 @@ void KrnlDefineLoopsOp::build(
|
|||
}
|
||||
|
||||
void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) {
|
||||
auto num_loop_attr = op.getAttrOfType<IntegerAttr>(op.getNumLoopsAttrName());
|
||||
p << "krnl.define_loops " << num_loop_attr.getValue().getSExtValue();
|
||||
auto numLoopAttr =
|
||||
op.getAttrOfType<IntegerAttr>(KrnlDefineLoopsOp::getNumLoopsAttrName());
|
||||
p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue();
|
||||
}
|
||||
|
||||
ParseResult parseKrnlDefineLoopsOp(
|
||||
OpAsmParser& parser, OperationState& result) {
|
||||
// Parse the attribute indicating number of loops defined.
|
||||
IntegerAttr num_loops;
|
||||
IntegerAttr numLoops;
|
||||
auto& builder = parser.getBuilder();
|
||||
auto int32_type = builder.getIntegerType(64);
|
||||
if (parser.parseAttribute(num_loops, int32_type,
|
||||
auto intType = builder.getIntegerType(64);
|
||||
if (parser.parseAttribute(numLoops, intType,
|
||||
KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
|
||||
return failure();
|
||||
|
||||
auto loop_types = llvm::SmallVector<Type, 4>(
|
||||
num_loops.getValue().getSExtValue(), LoopType::get(builder.getContext()));
|
||||
if (parser.addTypesToList(loop_types, result.types))
|
||||
auto loopTypes = llvm::SmallVector<Type, 4>(
|
||||
numLoops.getValue().getSExtValue(), LoopType::get(builder.getContext()));
|
||||
if (parser.addTypesToList(loopTypes, result.types))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -142,39 +143,14 @@ ParseResult parseKrnlOptimizeLoopsOp(
|
|||
* %i0 = 10 to N : %i1 = M to 20
|
||||
*/
|
||||
void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
||||
ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops,
|
||||
ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds,
|
||||
ArrayRef<int> bound_types) {
|
||||
KrnlIterateOperandPack operandPack) {
|
||||
// Record optimized loops and the number of such loops.
|
||||
result.addOperands(optimized_loops);
|
||||
result.addOperands(operandPack.getOperands());
|
||||
result.addAttribute(
|
||||
KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes());
|
||||
|
||||
result.addAttribute(getNumOptimizedLoopsAttrName(),
|
||||
builder->getI64IntegerAttr(optimized_loops.size()));
|
||||
|
||||
// Record input loops and the number of such loops.
|
||||
result.addOperands(input_loops);
|
||||
result.addAttribute(getNumInputLoopsAttrName(),
|
||||
builder->getI64IntegerAttr(input_loops.size()));
|
||||
|
||||
// Record bound either as attribute or from operand list.
|
||||
auto next_operand_bound = operand_bounds.begin();
|
||||
auto next_const_bound = const_bounds.begin();
|
||||
for (size_t i = 0; i < bound_types.size(); i++) {
|
||||
auto bound_type = bound_types[i];
|
||||
if (bound_type == 0) {
|
||||
// Constant bound.
|
||||
result.addAttribute(getBoundAttrName(i / 2, i % 2),
|
||||
builder->getI64IntegerAttr(*next_const_bound));
|
||||
next_const_bound = std::next(next_const_bound);
|
||||
} else {
|
||||
// Operand bound.
|
||||
result.addOperands(*next_operand_bound);
|
||||
next_operand_bound = std::next(next_operand_bound);
|
||||
}
|
||||
}
|
||||
|
||||
// Record bound types as attribute:
|
||||
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
|
||||
builder->getI32ArrayAttr(bound_types));
|
||||
builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops()));
|
||||
|
||||
// Create a region and a block for the body. The arguments of the region are
|
||||
// the loop induction variables; there can be multiple induction variables
|
||||
|
@ -182,7 +158,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
|||
Region* bodyRegion = result.addRegion();
|
||||
auto* body = new Block();
|
||||
auto body_args = llvm::SmallVector<Type, 4>(
|
||||
input_loops.size(), IndexType::get(builder->getContext()));
|
||||
operandPack.getNumInputLoops(), IndexType::get(builder->getContext()));
|
||||
body->addArguments(body_args);
|
||||
bodyRegion->push_back(body);
|
||||
|
||||
|
@ -192,57 +168,31 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
|||
void print(OpAsmPrinter& p, KrnlIterateOp& op) {
|
||||
p << "krnl.iterate(";
|
||||
// Print optimized loops:
|
||||
auto num_optimized_loops = op.getNumOptimizedLoops();
|
||||
p.printOperands(op.operand_begin(), op.operand_begin() + num_optimized_loops);
|
||||
auto numOptimizedLoops = op.getNumOptimizedLoops();
|
||||
p.printOperands(op.operand_begin(), op.operand_begin() + numOptimizedLoops);
|
||||
p << ") with (";
|
||||
|
||||
// Set up iterator to input loops:
|
||||
auto num_input_loops = op.getNumInputLoops();
|
||||
auto input_loop_begin = op.operand_begin() + num_optimized_loops;
|
||||
auto inductionVars = op.bodyRegion().begin()->getArguments();
|
||||
auto boundItr =
|
||||
op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
|
||||
.getValue()
|
||||
.begin();
|
||||
auto operandItr = op.operand_begin() + numOptimizedLoops;
|
||||
|
||||
// Set up iterators to operand bounds.
|
||||
auto next_operand_bound = input_loop_begin + num_input_loops;
|
||||
|
||||
// Function to print a lower or upper bound.
|
||||
auto print_bound = [&](ArrayRef<Attribute> bound_types, size_t idx) {
|
||||
IntegerAttr type = bound_types[idx].dyn_cast<IntegerAttr>();
|
||||
if (type.getValue().getSExtValue() == 0) {
|
||||
// Bound is an integer attribute.
|
||||
auto bound_idx = idx / 2;
|
||||
auto is_ub = idx % 2;
|
||||
IntegerAttr bound = op.getAttrOfType<IntegerAttr>(
|
||||
KrnlIterateOp::getBoundAttrName(bound_idx, is_ub));
|
||||
p << bound.getValue().getSExtValue();
|
||||
} else {
|
||||
// Bound is an operand.
|
||||
p.printOperand(*next_operand_bound);
|
||||
next_operand_bound = std::next(next_operand_bound);
|
||||
}
|
||||
};
|
||||
|
||||
auto induction_variables = op.bodyRegion().front().getArguments();
|
||||
ArrayRef<Attribute> bound_types =
|
||||
op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundTypesAttrName())
|
||||
.getValue();
|
||||
|
||||
// Print input loop operands, induction variables and their ranges.
|
||||
for (size_t i = 0; i < num_input_loops; i++) {
|
||||
if (i != 0)
|
||||
p << ", ";
|
||||
|
||||
p.printOperand(*std::next(input_loop_begin, i));
|
||||
std::string delimiter;
|
||||
for (auto& var : inductionVars) {
|
||||
p << delimiter;
|
||||
p.printOperand(*operandItr++);
|
||||
p << " -> ";
|
||||
|
||||
// Print induction variable block argument.
|
||||
p.printOperand(induction_variables[i]);
|
||||
p.printOperand(var);
|
||||
p << " = ";
|
||||
|
||||
print_bound(bound_types, 2 * i); // Print lower bound.
|
||||
onnf::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "max", p);
|
||||
p << " to ";
|
||||
print_bound(bound_types, 2 * i + 1); // Print upper bound.
|
||||
onnf::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "min", p);
|
||||
delimiter = ", ";
|
||||
}
|
||||
p << ")";
|
||||
|
||||
p << ")";
|
||||
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
}
|
||||
|
@ -250,80 +200,109 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
|
|||
ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
||||
auto builder = parser.getBuilder();
|
||||
auto context = builder.getContext();
|
||||
onnf::KrnlDialectOperandParser operand_parser(parser);
|
||||
onnf::KrnlDialectOperandParser operandParser(parser);
|
||||
|
||||
// Parse optimized loops:
|
||||
SmallVector<OpAsmParser::OperandType, 4> num_optimized_loops;
|
||||
SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs;
|
||||
if (parser.parseOperandList(
|
||||
num_optimized_loops, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(num_optimized_loops,
|
||||
optimizedLoopRefs, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(optimizedLoopRefs,
|
||||
LoopType::get(result.getContext()), result.operands))
|
||||
return failure();
|
||||
|
||||
// Record how many optimized loops did we parse.
|
||||
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
|
||||
builder.getI64IntegerAttr(num_optimized_loops.size()));
|
||||
builder.getI64IntegerAttr(optimizedLoopRefs.size()));
|
||||
|
||||
// Parse input loops and their lower and upper bounds.
|
||||
SmallVector<OpAsmParser::OperandType, 4> in_loop_refs, induction_var_refs;
|
||||
SmallVector<Value*, 4> in_loop_operands, operand_bounds;
|
||||
SmallVector<Attribute, 4> bound_types;
|
||||
SmallVector<IntegerAttr, 4> const_bounds;
|
||||
SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs;
|
||||
SmallVector<Attribute, 4> boundMaps;
|
||||
|
||||
if (parser.parseKeyword("with") || parser.parseLParen())
|
||||
return failure();
|
||||
|
||||
// A function to parse a lower or upper bound.
|
||||
auto parse_bound = [&result, &builder, &operand_parser, &parser, &bound_types,
|
||||
&operand_bounds, &const_bounds](
|
||||
bool is_ub, size_t bound_pair_count) -> ParseResult {
|
||||
auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps](
|
||||
bool isUpper) -> ParseResult {
|
||||
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if
|
||||
// the map has multiple results.
|
||||
bool failedToParsedMinMax =
|
||||
failed(parser.parseOptionalKeyword(isUpper ? "min" : "max"));
|
||||
|
||||
// Try parse an SSA operand.
|
||||
Value* bound;
|
||||
operand_parser.ParseOptionalOperand(builder.getIndexType(), bound);
|
||||
if (succeeded(operandParser.ParseOptionalOperand(
|
||||
builder.getIndexType(), result.operands))) {
|
||||
AffineMap map = builder.getSymbolIdentityMap();
|
||||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||
return success();
|
||||
}
|
||||
|
||||
if (bound != nullptr) {
|
||||
// Parsed an SSA id as bound.
|
||||
operand_bounds.emplace_back(bound);
|
||||
// Record bound_type as an operand type.
|
||||
bound_types.emplace_back(builder.getI32IntegerAttr(0));
|
||||
} else {
|
||||
// Bound is not an SSA id, then it must be an integer.
|
||||
// Parse an integer constant attribute.
|
||||
IntegerAttr boundAttr;
|
||||
if (parser.parseAttribute(boundAttr, builder.getIndexType(),
|
||||
KrnlIterateOp::getBoundAttrName(bound_pair_count, is_ub),
|
||||
result.attributes))
|
||||
// Bound is not an SSA id, then it must be an integer.
|
||||
// Parse an integer constant attribute.
|
||||
// Get the attribute location.
|
||||
llvm::SMLoc attrLoc = parser.getCurrentLocation();
|
||||
Attribute boundAttr;
|
||||
llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer;
|
||||
if (parser.parseAttribute(
|
||||
boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer))
|
||||
return failure();
|
||||
|
||||
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
|
||||
unsigned currentNumOperands = result.operands.size();
|
||||
unsigned numDims = 0;
|
||||
if (parseDimAndSymbolList(parser, result.operands, numDims))
|
||||
return failure();
|
||||
const_bounds.emplace_back(
|
||||
builder.getIntegerAttr(builder.getIndexType(), boundAttr.getValue()));
|
||||
|
||||
// Record that the bound_type is a constant integer attribute.
|
||||
bound_types.emplace_back(builder.getI32IntegerAttr(1));
|
||||
auto map = affineMapAttr.getValue();
|
||||
if (map.getNumDims() != numDims)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"dim operand count and integer set dim count must match");
|
||||
|
||||
unsigned numDimAndSymbolOperands =
|
||||
result.operands.size() - currentNumOperands;
|
||||
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"symbol operand count and integer set symbol count must match");
|
||||
|
||||
// If the map has multiple results, make sure that we parsed the min/max
|
||||
// prefix.
|
||||
if (map.getNumResults() > 1 && failedToParsedMinMax) {
|
||||
if (isUpper)
|
||||
return parser.emitError(attrLoc,
|
||||
"upper loop bound affine map with multiple "
|
||||
"results requires 'min' prefix");
|
||||
return parser.emitError(attrLoc,
|
||||
"lower loop bound affine mapwith "
|
||||
"multiple results requires 'max' prefix");
|
||||
}
|
||||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
|
||||
AffineMap map =
|
||||
builder.getConstantAffineMap(integerAttr.getValue().getSExtValue());
|
||||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||
}
|
||||
};
|
||||
|
||||
bool keep_parsing; // Do we keep parsing loops/bounds?
|
||||
size_t bound_pair_count = 0; // Record the number of bound pairs parsed.
|
||||
bool keepParsing; // Do we keep parsing loops/bounds?
|
||||
do {
|
||||
// Parse an input loop operand;
|
||||
Value* in_loop_operand;
|
||||
operand_parser.ParseOperand(LoopType::get(context), in_loop_operand);
|
||||
in_loop_operands.emplace_back(in_loop_operand);
|
||||
|
||||
operandParser.ParseOperand(LoopType::get(context), result.operands);
|
||||
parser.parseArrow();
|
||||
|
||||
// Parse induction variable.
|
||||
OpAsmParser::OperandType induction_var;
|
||||
if (parser.parseRegionArgument(induction_var) || parser.parseEqual())
|
||||
OpAsmParser::OperandType inductionVar;
|
||||
if (parser.parseRegionArgument(inductionVar) || parser.parseEqual())
|
||||
return failure();
|
||||
induction_var_refs.emplace_back(induction_var);
|
||||
inductionVarRefs.emplace_back(inductionVar);
|
||||
|
||||
// Parse bound par (min to max).
|
||||
if (parse_bound(false, bound_pair_count) || parser.parseKeyword("to") ||
|
||||
parse_bound(true, bound_pair_count))
|
||||
if (parseBound(/*isUpper=*/false) || parser.parseKeyword("to") ||
|
||||
parseBound(/*isUpper=*/true))
|
||||
return failure();
|
||||
|
||||
bound_pair_count++;
|
||||
// We may fail to parse a comma if an operand bound is followed by
|
||||
// a comma and the next input loop operand, in which case
|
||||
// the entire "{operand bound}, {input_loop_operand}" sequence will
|
||||
|
@ -331,33 +310,19 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
|||
parser.parseOptionalComma();
|
||||
|
||||
// If we don't see a RParen token, we keep parsing.
|
||||
keep_parsing = failed(parser.parseOptionalRParen());
|
||||
} while (keep_parsing);
|
||||
keepParsing = failed(parser.parseOptionalRParen());
|
||||
} while (keepParsing);
|
||||
|
||||
// At this point, there shouldn't be any operands left to parse.
|
||||
if (operand_parser.has_operand_left())
|
||||
if (operandParser.hasOperandLeft())
|
||||
return parser.emitError(parser.getCurrentLocation());
|
||||
result.addAttribute(
|
||||
KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps));
|
||||
|
||||
// Record how many input loops did we parse.
|
||||
result.addOperands(in_loop_operands);
|
||||
result.addAttribute(KrnlIterateOp::getNumInputLoopsAttrName(),
|
||||
builder.getI64IntegerAttr(in_loop_operands.size()));
|
||||
|
||||
// Add operand bounds to the list of operands of current operation.
|
||||
result.addOperands(operand_bounds);
|
||||
|
||||
// A list of 2N elements where the (2n) and (2n+1) th element specifies
|
||||
// whether the lower and upper bound of the n'th induction variable is stored
|
||||
// as an operand or as an attribute. N being the number of input loops
|
||||
// specified in this krnl.iterate operation.
|
||||
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
|
||||
builder.getArrayAttr(bound_types));
|
||||
|
||||
// Parse the schedule body region.
|
||||
Region* region = result.addRegion();
|
||||
SmallVector<Type, 4> induction_var_types(
|
||||
induction_var_refs.size(), builder.getIndexType());
|
||||
if (parser.parseRegion(*region, induction_var_refs, induction_var_types))
|
||||
SmallVector<Type, 4> inductionVarTypes(
|
||||
inductionVarRefs.size(), builder.getIndexType());
|
||||
if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes))
|
||||
return failure();
|
||||
|
||||
// Ensure iterate region is closed off with krnl.terminate.
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
|
||||
#include "src/compiler/dialect/krnl/krnl_types.hpp"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
|
||||
def Krnl_Dialect : Dialect {
|
||||
let name = "krnl";
|
||||
let cppNamespace = "";
|
||||
|
@ -119,17 +120,14 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
|
|||
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops, "
|
||||
"ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds, "
|
||||
"ArrayRef<int> bound_types">
|
||||
"KrnlIterateOperandPack operandPack">
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
||||
// In krnl.iterate operation, three types of SSA values are stored:
|
||||
// In krnl.iterate operation, operands are stored as such
|
||||
// - Optimized krnl.loops.
|
||||
// - Input krnl.loops.
|
||||
// - SSA value based induction variable bound (parametric bound).
|
||||
// - Input krnl.loops and their operand bounds. (TODO(Tian) explain better how we store them).
|
||||
|
||||
// We record the number of optimized and input loops to separate these three
|
||||
// group of operands out.
|
||||
static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; }
|
||||
|
@ -143,32 +141,8 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
|
|||
return num_optimized_loops;
|
||||
}
|
||||
|
||||
static StringRef getNumInputLoopsAttrName() { return "num_input_loops"; }
|
||||
|
||||
int64_t getNumInputLoops() {
|
||||
auto num_loops =
|
||||
getAttrOfType<IntegerAttr>(
|
||||
getNumInputLoopsAttrName())
|
||||
.getValue()
|
||||
.getSExtValue();
|
||||
return num_loops;
|
||||
}
|
||||
|
||||
// Constant bounds are stored here as a list attribute.
|
||||
static StringRef getConstantBoundsAttrName() { return "constant_bounds"; }
|
||||
|
||||
// Store type of each bound as three types:
|
||||
// - 0 = constant attribute.
|
||||
// - 1 = operand type.
|
||||
// - 2 = affine maps (TODO).
|
||||
static StringRef getBoundTypesAttrName() { return "bound_types"; }
|
||||
|
||||
// Get dynamic attribute name for the i-th lower and upper bound.
|
||||
static std::string getBoundAttrName(int64_t i, bool is_ub) {
|
||||
std::string bound_type = is_ub ? "_ub" : "_lb";
|
||||
std::string bound_idx = std::to_string(i);
|
||||
return "__bound_" + bound_idx + bound_type;
|
||||
}
|
||||
// Get name of the attribute for storing bound represented using affine maps.
|
||||
static StringRef getBoundsAttrName() { return "bounds"; }
|
||||
}];
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
|
|
|
@ -1,52 +0,0 @@
|
|||
//===------------------ parser_helper.cpp - MLIR Operations ---------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "parser_helper.hpp"
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||
|
||||
namespace onnf {
|
||||
|
||||
mlir::ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||
mlir::Type operand_type, mlir::Value*& operand) {
|
||||
// If operand queue is empty, parse more operands and cache them.
|
||||
if (_operand_ref_queue.empty()) {
|
||||
// Parse operand types:
|
||||
llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> operand_refs;
|
||||
_parser.parseOperandList(operand_refs);
|
||||
|
||||
// Record operands:
|
||||
for (auto& operand_ref : operand_refs)
|
||||
_operand_ref_queue.emplace(operand_ref);
|
||||
}
|
||||
|
||||
// If we parsed some operand reference(s), resolve the ref to an operand:
|
||||
if (!_operand_ref_queue.empty()) {
|
||||
auto operand_ref = _operand_ref_queue.front();
|
||||
_operand_ref_queue.pop();
|
||||
|
||||
llvm::SmallVector<mlir::Value*, 1> operands;
|
||||
_parser.resolveOperand(operand_ref, operand_type, operands);
|
||||
operand = operands.front();
|
||||
return mlir::success();
|
||||
} else {
|
||||
operand = nullptr;
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
mlir::ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||
mlir::Type operand_type, mlir::Value*& operand) {
|
||||
ParseOptionalOperand(operand_type, operand);
|
||||
if (operand == nullptr)
|
||||
return _parser.emitError(
|
||||
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace onnf
|
|
@ -1,46 +0,0 @@
|
|||
//===------------------ parser_helper.hpp - MLIR Operations ---------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <queue>
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace onnf {
|
||||
|
||||
class KrnlDialectOperandParser {
|
||||
public:
|
||||
KrnlDialectOperandParser(mlir::OpAsmParser& parser)
|
||||
: _parser(parser), _builder(parser.getBuilder()){};
|
||||
|
||||
// Parse an optional operand.
|
||||
mlir::ParseResult ParseOptionalOperand(
|
||||
mlir::Type operand_type, mlir::Value*& operand);
|
||||
|
||||
// Parse a required operand.
|
||||
mlir::ParseResult ParseOperand(
|
||||
mlir::Type operand_type, mlir::Value*& operand);
|
||||
|
||||
// Do we have more operands to parse?
|
||||
bool has_operand_left() { return !_operand_ref_queue.empty(); }
|
||||
|
||||
private:
|
||||
mlir::OpAsmParser& _parser;
|
||||
|
||||
mlir::Builder& _builder;
|
||||
|
||||
// A queue storing the parsed SSA id references.
|
||||
std::queue<mlir::OpAsmParser::OperandType> _operand_ref_queue;
|
||||
};
|
||||
|
||||
} // namespace onnf
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
|
||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
|
||||
|
@ -43,13 +44,12 @@ static MemRefType convertTensorToMemRef(TensorType type) {
|
|||
}
|
||||
|
||||
/// Insert an allocation and deallocation for the given MemRefType.
|
||||
static Value* insertAllocAndDealloc(
|
||||
MemRefType type, Location loc, PatternRewriter& rewriter,
|
||||
Value *oldMemRef = nullptr) {
|
||||
static Value* insertAllocAndDealloc(MemRefType type, Location loc,
|
||||
PatternRewriter& rewriter, Value* oldMemRef = nullptr) {
|
||||
// Put together alloc operands for any dynamic dimensions of the memref.
|
||||
AllocOp alloc;
|
||||
if (oldMemRef) {
|
||||
SmallVector<Value *, 4> allocOperands;
|
||||
SmallVector<Value*, 4> allocOperands;
|
||||
auto memRefShape = type.getShape();
|
||||
for (int i = 0; i < memRefShape.size(); ++i)
|
||||
if (memRefShape[i] < 0)
|
||||
|
@ -95,7 +95,7 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
|||
// dimensions with the result at this pre-optimization phase.
|
||||
// TODO: verify that dimensions match.
|
||||
// TODO: can the dimension of the result differ after optimizations?
|
||||
Value *alloc;
|
||||
Value* alloc;
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||
else
|
||||
|
@ -122,33 +122,22 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
|||
}
|
||||
Block& optimizationBlock = optimizedLoopsOp.region().front();
|
||||
|
||||
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
||||
// Iterate over the loop nest.
|
||||
// TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
|
||||
// to KrnlIterateOp instead.
|
||||
SmallVector<Value*, 8> operandBounds;
|
||||
SmallVector<int64_t, 8> constBounds;
|
||||
SmallVector<int, 16> boundTypes;
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
if (memRefShape[i] < 0) {
|
||||
// This is a dynamic value, hence use operands.
|
||||
// Lower bound
|
||||
constBounds.push_back(0);
|
||||
boundTypes.push_back(0);
|
||||
// Upper bound
|
||||
operandBounds.push_back(
|
||||
pack.pushConstantBound(0);
|
||||
pack.pushOperandBound(
|
||||
rewriter.create<DimOp>(loc, operands[0], i).getResult());
|
||||
boundTypes.push_back(1);
|
||||
} else {
|
||||
// Lower bound
|
||||
constBounds.push_back(0);
|
||||
boundTypes.push_back(0);
|
||||
// Upper bound
|
||||
constBounds.push_back(memRefShape[i]);
|
||||
boundTypes.push_back(0);
|
||||
pack.pushConstantBound(0);
|
||||
pack.pushConstantBound(memRefShape[i]);
|
||||
}
|
||||
}
|
||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, originalLoops,
|
||||
optimizedLoops, operandBounds, constBounds, boundTypes);
|
||||
|
||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||
Block& iterationBlock = iterateOp.bodyRegion().front();
|
||||
|
||||
// Now perform the insertions into the body of the
|
||||
|
@ -169,14 +158,12 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
|||
SmallVector<Value*, 4> loopIVs;
|
||||
for (auto arg : iterationBlock.getArguments())
|
||||
loopIVs.push_back(arg);
|
||||
auto loadedFirstVal =
|
||||
rewriter.create<LoadOp>(loc, operands[0], loopIVs);
|
||||
auto loadedSecondVal =
|
||||
rewriter.create<LoadOp>(loc, operands[1], loopIVs);
|
||||
auto loadedFirstVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
|
||||
auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
|
||||
|
||||
// TODO: Choose type of the Add for now use the Float Add.
|
||||
auto addOpResult = rewriter.create<AddFOp>(
|
||||
loc, loadedFirstVal, loadedSecondVal);
|
||||
auto addOpResult =
|
||||
rewriter.create<AddFOp>(loc, loadedFirstVal, loadedSecondVal);
|
||||
|
||||
// Store result in the resulting array.
|
||||
rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
|
||||
|
@ -209,8 +196,8 @@ struct TensorTypeConverter : public TypeConverter {
|
|||
/// inputs. Once unranked results can be handled gracefully this
|
||||
/// override needs to be removed in favour of the original MLIR one.]
|
||||
bool isSignatureLegal(FunctionType funcType) {
|
||||
return llvm::all_of(funcType.getInputs(),
|
||||
[this](Type type) { return isLegal(type); });
|
||||
return llvm::all_of(
|
||||
funcType.getInputs(), [this](Type type) { return isLegal(type); });
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -272,8 +259,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
|||
// With the target and rewrite patterns defined, we can now attempt the
|
||||
// conversion. The conversion will signal failure if any of our `illegal`
|
||||
// operations were not converted successfully.
|
||||
if (failed(applyPartialConversion(
|
||||
module, target, patterns)))
|
||||
if (failed(applyPartialConversion(module, target, patterns)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
|
|
@ -17,9 +17,10 @@ class Pass;
|
|||
|
||||
std::unique_ptr<Pass> createShapeInferencePass();
|
||||
|
||||
/// Pass for lowering frontend dialects to Krnl IR dialect.
|
||||
std::unique_ptr<mlir::Pass> createLowerToKrnlPass();
|
||||
/// Add pass for lowering to Krnl IR.
|
||||
std::unique_ptr<Pass> createLowerToKrnlPass();
|
||||
|
||||
// TODO: Add pass for lowering to LLVM IR.
|
||||
/// Pass for lowering frontend dialects to Krnl IR dialect.
|
||||
std::unique_ptr<Pass> createLowerKrnlPass();
|
||||
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -75,7 +75,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
|||
if (auto terminator_op = f.getBody().back().getTerminator()) {
|
||||
auto results = terminator_op->getOperandTypes();
|
||||
f.setType(FunctionType::get(f.getType().getInputs(),
|
||||
std::vector<Type>(results.begin(), results.end()), f.getContext()));
|
||||
std::vector<Type>(results.begin(), results.end()), f.getContext()));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,8 +6,7 @@ target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT})
|
|||
|
||||
target_link_libraries(onnf-opt compiler ${MLIRLibs})
|
||||
whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs})
|
||||
whole_archive_link_onnf(onnf-opt onnf_lower_frontend)
|
||||
whole_archive_link_onnf(onnf-opt onnf_shape_inference)
|
||||
whole_archive_link_onnf(onnf-opt onnf_transform onnf_lower_frontend onnf_shape_inference)
|
||||
|
||||
# TODO: need to investigate how to whole-archive link compiler pass to onnf-opt.
|
||||
target_link_libraries(onnf-opt compiler)
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
add_library(onnf_transform lower_krnl.cpp)
|
||||
|
||||
target_include_directories(onnf_transform
|
||||
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||
${ONNF_SRC_ROOT})
|
||||
target_link_libraries(onnf_transform ${MLIRLibs})
|
||||
add_dependencies(onnf_transform gen_krnl_ops)
|
|
@ -0,0 +1,161 @@
|
|||
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||
#include "src/compiler/pass/passes.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Krnl to Affine Rewrite Patterns: KrnlIterate operation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
||||
using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
KrnlIterateOp iterateOp, PatternRewriter& rewriter) const override {
|
||||
auto boundMapAttrs =
|
||||
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
|
||||
.getValue();
|
||||
auto operandItr =
|
||||
iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops();
|
||||
SmallVector<AffineForOp, 4> nestedForOps;
|
||||
for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) {
|
||||
// Consume input loop operand, currently do not do anything with it.
|
||||
operandItr++;
|
||||
|
||||
// Organize operands into lower/upper bounds in affine.for ready formats.
|
||||
SmallVector<Value*, 4> lbOperands, ubOperands;
|
||||
AffineMap lbMap, ubMap;
|
||||
for (int boundType = 0; boundType < 2; boundType++) {
|
||||
auto& operands = boundType == 0 ? lbOperands : ubOperands;
|
||||
auto& map = boundType == 0 ? lbMap : ubMap;
|
||||
map = boundMapAttrs[boundIdx + boundType]
|
||||
.cast<AffineMapAttr>()
|
||||
.getValue();
|
||||
operands.insert(
|
||||
operands.end(), operandItr, operandItr + map.getNumInputs());
|
||||
std::advance(operandItr, map.getNumInputs());
|
||||
}
|
||||
|
||||
nestedForOps.emplace_back(rewriter.create<AffineForOp>(
|
||||
iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap));
|
||||
rewriter.setInsertionPoint(nestedForOps.back().getBody(),
|
||||
nestedForOps.back().getBody()->begin());
|
||||
}
|
||||
|
||||
// Replace induction variable references from those introduced by a
|
||||
// single krnl.iterate to those introduced by multiple affine.for
|
||||
// operations.
|
||||
for (size_t i = 0; i < nestedForOps.size() - 1; i++) {
|
||||
auto iterateIV = iterateOp.bodyRegion().front().getArgument(0);
|
||||
auto forIV = nestedForOps[i].getBody()->getArgument(0);
|
||||
iterateIV->replaceAllUsesWith(forIV);
|
||||
iterateOp.bodyRegion().front().eraseArgument(0);
|
||||
}
|
||||
|
||||
// Pop krnl.iterate body region block arguments, leave the last one
|
||||
// for convenience (it'll be taken care of by region inlining).
|
||||
while (iterateOp.bodyRegion().front().getNumArguments() > 1)
|
||||
iterateOp.bodyRegion().front().eraseArgument(0);
|
||||
|
||||
// Transfer krnl.iterate region to innermost for op.
|
||||
auto innermostForOp = nestedForOps.back();
|
||||
innermostForOp.region().getBlocks().clear();
|
||||
rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(),
|
||||
innermostForOp.region().end());
|
||||
|
||||
rewriter.eraseOp(iterateOp);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Krnl to Affine Rewrite Patterns: KrnlTerminator operation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
|
||||
public:
|
||||
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
KrnlTerminatorOp op, PatternRewriter& rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Krnl to Affine Rewrite Patterns: KrnlDefineLoops operation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
|
||||
public:
|
||||
using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
KrnlDefineLoopsOp op, PatternRewriter& rewriter) const override {
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Krnl to Affine Rewrite Patterns: KrnlOptimizeLoops operation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> {
|
||||
public:
|
||||
using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
KrnlOptimizeLoopsOp op, PatternRewriter& rewriter) const override {
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KrnlToAffineLoweringPass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This is a partial lowering to affine loops of the krnl dialect operations.
|
||||
/// At this stage the dialect will contain standard operations as well like
|
||||
/// add and multiply, this pass will leave these operations intact.
|
||||
namespace {
|
||||
struct KrnlToAffineLoweringPass
|
||||
: public FunctionPass<KrnlToAffineLoweringPass> {
|
||||
void runOnFunction() final;
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
void KrnlToAffineLoweringPass::runOnFunction() {
|
||||
auto function = getFunction();
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
|
||||
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
|
||||
// We expect IR to be free of Krnl Dialect Ops.
|
||||
target.addIllegalDialect<KrnlOpsDialect>();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
||||
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext());
|
||||
|
||||
if (failed(applyPartialConversion(getFunction(), target, patterns)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
|
||||
return std::make_unique<KrnlToAffineLoweringPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<KrnlToAffineLoweringPass> pass(
|
||||
"lower-krnl", "Lower Krnl dialect.");
|
16
src/main.cpp
16
src/main.cpp
|
@ -28,6 +28,7 @@
|
|||
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "llvm/Bitcode/BitcodeWriter.h"
|
||||
#include "llvm/Support/FileUtilities.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
@ -38,6 +39,8 @@
|
|||
#include "src/compiler/pass/passes.hpp"
|
||||
|
||||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
@ -125,7 +128,20 @@ int main(int ac, char* av[]) {
|
|||
pm.addPass(mlir::createShapeInferencePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createLowerToKrnlPass());
|
||||
pm.addPass(mlir::createLowerKrnlPass());
|
||||
pm.addPass(mlir::createLowerAffinePass());
|
||||
pm.addPass(mlir::createLowerToCFGPass());
|
||||
pm.addPass(mlir::createLowerToLLVMPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.run(*module);
|
||||
|
||||
// Write LLVM bitcode to disk.
|
||||
std::error_code EC;
|
||||
llvm::raw_fd_ostream moduleBitcodeStream(
|
||||
"model.bc", EC, llvm::sys::fs::F_None);
|
||||
llvm::WriteBitcodeToFile(
|
||||
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
|
||||
moduleBitcodeStream.flush();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
// RUN: onnf-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s
|
||||
// RUN: onnf-opt %s | FileCheck %s
|
||||
|
||||
// GENERIC-DAG: #{{.*}} = () -> (0)
|
||||
// GENERIC-DAG: #{{.*}} = () -> (10)
|
||||
// GENERIC-DAG: #{{.*}} = () -> (1)
|
||||
// GENERIC-DAG: #{{.*}} = () -> (11)
|
||||
// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 - d1)
|
||||
// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 + d1)
|
||||
|
||||
func @simple_iterate(%N : index) {
|
||||
%ii, %ij, %ik = krnl.define_loops 3
|
||||
%oi, %oj, %ok = krnl.optimize_loops {
|
||||
krnl.return_loops %ii, %ij, %ik
|
||||
} : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
|
||||
|
||||
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
|
||||
// GENERIC-NEXT: "krnl.terminate"() : () -> ()
|
||||
// GENERIC-NEXT: bounds = [#{{.*}}, #{{.*}}, #{{.*}}, #{{.*}}]
|
||||
|
||||
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 1 to 11) {
|
||||
krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 1 to 11) {
|
||||
|
||||
}
|
||||
|
||||
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
|
||||
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) {
|
||||
krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to 10, %ij -> %j = 0 to 10) {
|
||||
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}) ( {
|
||||
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
|
||||
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10) {
|
||||
krnl.iterate(%ok) with (%ik -> %k = 0 to 10) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
|
||||
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to %{{.*}}, %{{.*}} -> %{{.*}} = 0 to 10) {
|
||||
krnl.iterate(%oi, %oj) with (%ii -> %i = 0 to %N, %ij -> %j = 0 to 10) {
|
||||
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func @affine_map_bound(%N : index) {
|
||||
%ii, %ij, %ik = krnl.define_loops 3
|
||||
%oi, %oj, %ok = krnl.optimize_loops {
|
||||
krnl.return_loops %ii, %ij, %ik
|
||||
} : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
|
||||
|
||||
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
|
||||
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) {
|
||||
krnl.iterate(%oi, %oj) with (%ii -> %i = ()->(0)() to ()->(10)(), %ij -> %j = 0 to 10) {
|
||||
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
|
||||
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = #{{.*}}(%{{.*}}, %{{.*}}) to #{{.*}}(%{{.*}}, %{{.*}})) {
|
||||
krnl.iterate(%ok) with (%ik -> %k = (d0, d1)->(d0 - d1)(%i, %j) to (d0, d1)->(d0 + d1)(%i, %j)) {
|
||||
|
||||
}
|
||||
|
||||
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
|
||||
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = max #map{{.*}}(%{{.*}}, %{{.*}}) to min #map{{.*}}(%{{.*}}, %{{.*}})[%{{.*}}]) {
|
||||
krnl.iterate(%ok) with (%ik -> %k = max (d0, d1)->(d0 - d1, 0)(%i, %j) to min (d0, d1)[s0]->(d0 + d1, s0)(%i, %j)[%N]) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue