Merge pull request #9 from clang-ykt/use-mlir-in-llvm-project

Update to latest MLIR
This commit is contained in:
Tian Jin 2020-01-06 16:14:27 -05:00 committed by GitHub
commit 322002f509
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 214 additions and 227 deletions

View File

@ -7,52 +7,33 @@ jobs:
steps: steps:
- checkout - checkout
- run: - run:
name: "Pull Submodules" name: Pull Submodules
command: | command: |
git submodule update --init --recursive git submodule update --init --recursive
- run: - run:
name: Check current directory name: Installing GCC, CMake, Ninja, Protobuf
command: pwd command: sudo apt-get update && sudo apt-get install -y gcc g++ cmake ninja-build protobuf-compiler
- run: # Use cached mlir installation if possible.
name: Check current directory content
command: ls
- run:
name: Installing GCC
command: 'sudo apt-get update && sudo apt-get install -y gcc g++'
- run:
name: Install CMAKE
command: 'sudo apt-get update && sudo apt-get install -y cmake ninja-build'
- run:
name: Install Protobuf
command: 'sudo apt-get update && sudo apt-get install -y protobuf-compiler'
- run:
name: Check gcc version
command: gcc --version
- restore_cache: - restore_cache:
key: ONNF-MLIR-{{ arch }} key: V2-LLVM-PROJECT-{{ arch }}
- run: - run:
name: Install MLIR name: Install MLIR
command: | command: |
# Check whether cache restoration succeeds by checking whether
# mlir-opt executable exists.
if [ ! -f llvm-project/build/bin/mlir-opt ]; then if [ ! -f llvm-project/build/bin/mlir-opt ]; then
git clone https://github.com/llvm/llvm-project.git export MAKEFLAGS=-j4
cd llvm-project && git checkout 9b6ad8466bb8b97082b705270603ad7f4559e931 && cd .. source .circleci/install-mlir.sh
git clone https://github.com/tensorflow/mlir llvm-project/llvm/projects/mlir
cd llvm-project/llvm/projects/mlir && git checkout 0710266d0f56cf6ab0f437badbd7416b6cecdf5f && cd ../../../..
mkdir llvm-project/build
cd llvm-project/build
cmake -G Ninja ../llvm -DLLVM_ENABLE_RTTI=ON -DLLVM_BUILD_EXAMPLES=OFF -DLLVM_TARGETS_TO_BUILD="host" -DCMAKE_BUILD_TYPE=Release
CMAKE_EXE_LINKER_FLAGS="-Wl,--reduce-memory-overheads -Wl,--hash-size=512" cmake --build . --target check-mlir -- -j 4
fi fi
- save_cache: - save_cache:
key: ONNF-MLIR-{{ arch }} key: V2-LLVM-PROJECT-{{ arch }}
paths: paths:
- llvm-project - llvm-project
- run: - run:
name: Install ONNF name: Install ONNF
command: | command: |
mkdir build && cd build mkdir build && cd build
LLVM_SRC=$(pwd)/../llvm-project/llvm LLVM_BUILD=$(pwd)/../llvm-project/build cmake .. LLVM_PROJ_SRC=$(pwd)/../llvm-project/ LLVM_PROJ_BUILD=$(pwd)/../llvm-project/build cmake ..
make all make all
LIT_OPTS=-v make check-mlir-lit LIT_OPTS=-v make check-mlir-lit
- run: - run:

12
.circleci/install-mlir.sh Normal file
View File

@ -0,0 +1,12 @@
git clone https://github.com/llvm/llvm-project.git
mkdir llvm-project/build
cd llvm-project/build
cmake -G Ninja ../llvm \
-DLLVM_ENABLE_PROJECTS=mlir \
-DLLVM_BUILD_EXAMPLES=ON \
-DLLVM_TARGETS_TO_BUILD="host" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON
cmake --build . --target check-mlir -- ${MAKEFLAGS}

View File

@ -1,38 +1,38 @@
# Path to LLVM source folder. # Path to LLVM source folder.
if(DEFINED ENV{LLVM_SRC}) if(DEFINED ENV{LLVM_PROJ_SRC})
set(LLVM_SRC $ENV{LLVM_SRC}) set(LLVM_PROJ_SRC $ENV{LLVM_PROJ_SRC})
if(EXISTS ${LLVM_SRC}) if(EXISTS ${LLVM_PROJ_SRC})
message(STATUS "LLVM_SRC " ${LLVM_SRC}) message(STATUS "LLVM_PROJ_SRC " ${LLVM_PROJ_SRC})
else() else()
message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: " message(FATAL_ERROR "The path specified by LLVM_PROJ_SRC does not exist: "
${LLVM_SRC}) ${LLVM_PROJ_SRC})
endif() endif()
else() else()
message(FATAL_ERROR "env variable LLVM_SRC not set") message(FATAL_ERROR "env variable LLVM_PROJ_SRC not set")
endif() endif()
# Path to LLVM build folder # Path to LLVM build folder
if(DEFINED ENV{LLVM_BUILD}) if(DEFINED ENV{LLVM_PROJ_BUILD})
set(LLVM_BUILD $ENV{LLVM_BUILD}) set(LLVM_PROJ_BUILD $ENV{LLVM_PROJ_BUILD})
if(EXISTS ${LLVM_BUILD}) if(EXISTS ${LLVM_PROJ_BUILD})
message(STATUS "LLVM_BUILD " ${LLVM_BUILD}) message(STATUS "LLVM_PROJ_BUILD " ${LLVM_PROJ_BUILD})
else() else()
message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: " message(FATAL_ERROR "The path specified by LLVM_PROJ_BUILD does not exist: "
${LLVM_BUILD}) ${LLVM_PROJ_BUILD})
endif() endif()
else() else()
message(FATAL_ERROR "env variable LLVM_BUILD not set") message(FATAL_ERROR "env variable LLVM_PROJ_BUILD not set")
endif() endif()
# LLVM project lib folder # LLVM project lib folder
set(LLVM_PROJECT_LIB ${LLVM_BUILD}/lib) set(LLVM_PROJECT_LIB ${LLVM_PROJ_BUILD}/lib)
# Include paths for MLIR # Include paths for MLIR
set(LLVM_SRC_INCLUDE_PATH ${LLVM_SRC}/include) set(LLVM_SRC_INCLUDE_PATH ${LLVM_PROJ_SRC}/llvm/include)
set(LLVM_BIN_INCLUDE_PATH ${LLVM_BUILD}/include) set(LLVM_BIN_INCLUDE_PATH ${LLVM_PROJ_BUILD}/include)
set(MLIR_SRC_INCLUDE_PATH ${LLVM_SRC}/projects/mlir/include) set(MLIR_SRC_INCLUDE_PATH ${LLVM_PROJ_SRC}/mlir/include)
set(MLIR_BIN_INCLUDE_PATH ${LLVM_BUILD}/projects/mlir/include) set(MLIR_BIN_INCLUDE_PATH ${LLVM_PROJ_BUILD}/tools/mlir/include)
set(MLIR_TOOLS_DIR ${LLVM_BUILD}/bin) set(MLIR_TOOLS_DIR ${LLVM_PROJ_BUILD}/bin)
set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/bin) set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/bin)
set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir) set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
@ -173,7 +173,7 @@ function(whole_archive_link target lib_dir)
endfunction(whole_archive_link) endfunction(whole_archive_link)
function(whole_archive_link_mlir target) function(whole_archive_link_mlir target)
whole_archive_link(${target} ${LLVM_BUILD}/lib ${ARGN}) whole_archive_link(${target} ${LLVM_PROJ_BUILD}/lib ${ARGN})
endfunction(whole_archive_link_mlir) endfunction(whole_archive_link_mlir)
function(whole_archive_link_onnf target) function(whole_archive_link_onnf target)
@ -184,7 +184,7 @@ function(whole_archive_link_onnf target)
endfunction(whole_archive_link_onnf) endfunction(whole_archive_link_onnf)
set(LLVM_CMAKE_DIR set(LLVM_CMAKE_DIR
"${LLVM_BUILD}/lib/cmake/llvm" "${LLVM_PROJ_BUILD}/lib/cmake/llvm"
CACHE PATH "Path to LLVM cmake modules") CACHE PATH "Path to LLVM cmake modules")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(AddLLVM) include(AddLLVM)
@ -205,5 +205,5 @@ endfunction()
# table gen utility itself can be detected and cause re-compilation of .td file. # table gen utility itself can be detected and cause re-compilation of .td file.
add_executable(mlir-tblgen IMPORTED) add_executable(mlir-tblgen IMPORTED)
set_property(TARGET mlir-tblgen set_property(TARGET mlir-tblgen
PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen) PROPERTY IMPORTED_LOCATION ${LLVM_PROJ_BUILD}/bin/mlir-tblgen)
set(MLIR_TABLEGEN_EXE mlir-tblgen) set(MLIR_TABLEGEN_EXE mlir-tblgen)

View File

@ -71,7 +71,7 @@ struct OnnxOnnfSymbolMapping {
* @param name onnx tensor name. * @param name onnx tensor name.
* @return onnf tensor corresponding to `name`. * @return onnf tensor corresponding to `name`.
*/ */
mlir::Value *GetTensorByOnnxName(std::string name) { mlir::Value GetTensorByOnnxName(std::string name) {
assert(onnx_name2onnf_tensor.find(legalize_name(name)) != assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
onnx_name2onnf_tensor.end() && onnx_name2onnf_tensor.end() &&
"Tensor not found"); "Tensor not found");
@ -81,9 +81,9 @@ struct OnnxOnnfSymbolMapping {
/*! /*!
* Add a new mapping from onnx tensor name to MLIR symbol. * Add a new mapping from onnx tensor name to MLIR symbol.
* @param name onnx tensor name. * @param name onnx tensor name.
* @param tensor MLIR Value* pointer. * @param tensor MLIR Value pointer.
*/ */
void AddMapping(std::string name, mlir::Value *tensor) { void AddMapping(std::string name, mlir::Value tensor) {
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 && assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
"Tensor already exists."); "Tensor already exists.");
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor); onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
@ -97,7 +97,7 @@ private:
/*! /*!
* mapping from onnx tensor names to MLIR tensor. * mapping from onnx tensor names to MLIR tensor.
*/ */
std::map<std::string, mlir::Value*> onnx_name2onnf_tensor; std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
}; };
class FrontendGenImpl { class FrontendGenImpl {
@ -192,13 +192,13 @@ private:
/*! /*!
* Import a input tensor symbol by recording a new entry in frontend_symbols_ * Import a input tensor symbol by recording a new entry in frontend_symbols_
* recording the mapping between legalized onnx tensor name and mlir::Value* * recording the mapping between legalized onnx tensor name and mlir::Value
* for further lookup in computation node importing. * for further lookup in computation node importing.
* @param input onnx input tensor ValueInfoProto. * @param input onnx input tensor ValueInfoProto.
* @param symbol mlir input argument. * @param symbol mlir input argument.
*/ */
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input, void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
mlir::Value *symbol) { mlir::Value symbol) {
auto input_tensor_legalized_name = legalize_name(input.name()); auto input_tensor_legalized_name = legalize_name(input.name());
assert( assert(
!frontend_symbols_.ContainKey(input_tensor_legalized_name) && !frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
@ -480,7 +480,7 @@ private:
} }
void ImportNodeGeneric(onnx::NodeProto node) { void ImportNodeGeneric(onnx::NodeProto node) {
std::vector<mlir::Value *> inputs; std::vector<mlir::Value> inputs;
for (auto item : node.input()) { for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
@ -515,7 +515,7 @@ private:
onnx::NodeProto node, int nIn, int nOut, onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>> std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) { attrs) {
std::vector<mlir::Value *> inputs; std::vector<mlir::Value> inputs;
for (auto item : node.input()) { for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
@ -562,7 +562,7 @@ private:
onnx::NodeProto node, int nIn, int nOut, onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>> std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) { attrs) {
std::vector<mlir::Value *> inputs; std::vector<mlir::Value> inputs;
for (auto item : node.input()) { for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
@ -633,7 +633,7 @@ private:
} }
void ImportNode(onnx::NodeProto node) { void ImportNode(onnx::NodeProto node) {
std::vector<mlir::Value *> inputs; std::vector<mlir::Value> inputs;
for (auto item : node.input()) { for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
@ -662,17 +662,17 @@ private:
* Import output tensor, by doing the following: * Import output tensor, by doing the following:
* - Add the type of this output tensor to a list of tensor * - Add the type of this output tensor to a list of tensor
* types representing return types of this graph function. * types representing return types of this graph function.
* - Add this output tensor to the list of mlir::Value* * - Add this output tensor to the list of mlir::Value
* to be returned by the function representing computation graph. * to be returned by the function representing computation graph.
* @param output onnx output tensor ValueInfoProto. * @param output onnx output tensor ValueInfoProto.
* @param ret_types a vector of tensor types representing graph's * @param ret_types a vector of tensor types representing graph's
* output tensor types. * output tensor types.
* @param ret_vals a vector of mlir Value* representing graph's * @param ret_vals a vector of mlir Value representing graph's
* output tensor. * output tensor.
*/ */
void ImportOutputTensor(const onnx::ValueInfoProto &output, void ImportOutputTensor(const onnx::ValueInfoProto &output,
llvm::SmallVectorImpl<mlir::Type> &ret_types, llvm::SmallVectorImpl<mlir::Type> &ret_types,
llvm::SmallVectorImpl<mlir::Value *> &ret_vals) { llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
auto output_tensor_legalized_name = legalize_name(output.name()); auto output_tensor_legalized_name = legalize_name(output.name());
assert( assert(
frontend_symbols_.ContainKey(output_tensor_legalized_name) && frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
@ -722,7 +722,7 @@ private:
} }
llvm::SmallVector<mlir::Type, 4> ret_types; llvm::SmallVector<mlir::Type, 4> ret_types;
llvm::SmallVector<mlir::Value *, 4> ret_vals; llvm::SmallVector<mlir::Value, 4> ret_vals;
// Import the output tensors // Import the output tensors
for (const auto &output : graph.output()) { for (const auto &output : graph.output()) {
ImportOutputTensor(output, ret_types, ret_vals); ImportOutputTensor(output, ret_types, ret_vals);

View File

@ -9,8 +9,9 @@ namespace onnf {
using namespace mlir; using namespace mlir;
ParseResult KrnlDialectOperandParser::ParseOptionalOperand( ParseResult
const Type& operandType, Value*& operand) { KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
Value &operand) {
// If operand queue is empty, parse more operands and cache them. // If operand queue is empty, parse more operands and cache them.
if (_operandRefQueue.empty()) { if (_operandRefQueue.empty()) {
// Parse operand types: // Parse operand types:
@ -27,7 +28,7 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
auto operand_ref = _operandRefQueue.front(); auto operand_ref = _operandRefQueue.front();
_operandRefQueue.pop(); _operandRefQueue.pop();
llvm::SmallVector<Value*, 1> operands; llvm::SmallVector<Value, 1> operands;
_parser.resolveOperand(operand_ref, operandType, operands); _parser.resolveOperand(operand_ref, operandType, operands);
operand = operands.front(); operand = operands.front();
return success(); return success();
@ -38,8 +39,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
} }
ParseResult KrnlDialectOperandParser::ParseOptionalOperand( ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) { const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) {
Value* operand = nullptr; Value operand = nullptr;
if (ParseOptionalOperand(operandType, operand)) if (ParseOptionalOperand(operandType, operand))
return failure(); return failure();
@ -47,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
return success(); return success();
} }
ParseResult KrnlDialectOperandParser::ParseOperand( ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType,
const Type& operandType, Value*& operand) { Value &operand) {
if (ParseOptionalOperand(operandType, operand)) if (ParseOptionalOperand(operandType, operand))
return _parser.emitError( return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand."); _parser.getCurrentLocation(), "Expecting an operand.");
@ -56,7 +57,7 @@ ParseResult KrnlDialectOperandParser::ParseOperand(
} }
ParseResult KrnlDialectOperandParser::ParseOperand( ParseResult KrnlDialectOperandParser::ParseOperand(
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) { const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) {
if (ParseOptionalOperand(operandType, operandList)) if (ParseOptionalOperand(operandType, operandList))
return _parser.emitError( return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand."); _parser.getCurrentLocation(), "Expecting an operand.");
@ -129,7 +130,7 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
boundMaps.emplace_back(AffineMapAttr::get(map)); boundMaps.emplace_back(AffineMapAttr::get(map));
} }
void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) { void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
if (boundMaps.size() % 2 == 0) if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]); _operands.emplace_back(inputLoops[boundMaps.size() / 2]);
AffineMap map = builder.getSymbolIdentityMap(); AffineMap map = builder.getSymbolIdentityMap();

View File

@ -17,20 +17,22 @@ class KrnlDialectOperandParser {
: _parser(parser), _builder(parser.getBuilder()){}; : _parser(parser), _builder(parser.getBuilder()){};
// Parse an optional operand. // Parse an optional operand.
mlir::ParseResult ParseOptionalOperand( mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
const mlir::Type& operandType, mlir::Value*& operand); mlir::Value &operand);
// Parse an optional operand and push it to an operand list. // Parse an optional operand and push it to an operand list.
mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType, mlir::ParseResult
llvm::SmallVectorImpl<mlir::Value*>& operandList); ParseOptionalOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
// Parse a required operand. // Parse a required operand.
mlir::ParseResult ParseOperand( mlir::ParseResult ParseOperand(const mlir::Type &operandType,
const mlir::Type& operandType, mlir::Value*& operand); mlir::Value &operand);
// Parse a required operand and push it to an operand list. // Parse a required operand and push it to an operand list.
mlir::ParseResult ParseOperand(const mlir::Type& operandType, mlir::ParseResult
llvm::SmallVectorImpl<mlir::Value*>& operandList); ParseOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
// Do we have more operands to parse? // Do we have more operands to parse?
bool hasOperandLeft() { return !_operandRefQueue.empty(); } bool hasOperandLeft() { return !_operandRefQueue.empty(); }
@ -63,11 +65,10 @@ void printBound(mlir::AffineMapAttr boundMap,
namespace mlir { namespace mlir {
struct KrnlIterateOperandPack { struct KrnlIterateOperandPack {
KrnlIterateOperandPack(mlir::Builder& builder, KrnlIterateOperandPack(mlir::Builder &builder,
llvm::ArrayRef<mlir::Value*> inputLoops, llvm::ArrayRef<mlir::Value> inputLoops,
llvm::ArrayRef<mlir::Value*> optimizedLoops) llvm::ArrayRef<mlir::Value> optimizedLoops)
: builder(builder), : builder(builder), inputLoops(inputLoops),
inputLoops(inputLoops),
optimizedLoops(optimizedLoops) { optimizedLoops(optimizedLoops) {
_operands.insert( _operands.insert(
_operands.end(), optimizedLoops.begin(), optimizedLoops.end()); _operands.end(), optimizedLoops.begin(), optimizedLoops.end());
@ -75,9 +76,9 @@ struct KrnlIterateOperandPack {
void pushConstantBound(int64_t bound); void pushConstantBound(int64_t bound);
void pushOperandBound(mlir::Value* operand); void pushOperandBound(mlir::Value operand);
llvm::SmallVector<mlir::Value*, 8> getOperands() const { return _operands; } llvm::SmallVector<mlir::Value, 8> getOperands() const { return _operands; }
mlir::ArrayAttr getAttributes() const { mlir::ArrayAttr getAttributes() const {
return builder.getArrayAttr(boundMaps); return builder.getArrayAttr(boundMaps);
@ -90,11 +91,11 @@ struct KrnlIterateOperandPack {
private: private:
int _boundIdx = 0; int _boundIdx = 0;
llvm::SmallVector<mlir::Value*, 8> _operands; llvm::SmallVector<mlir::Value, 8> _operands;
llvm::SmallVector<mlir::Attribute, 8> boundMaps; llvm::SmallVector<mlir::Attribute, 8> boundMaps;
llvm::ArrayRef<mlir::Value*> inputLoops, optimizedLoops; llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;
mlir::Builder& builder; mlir::Builder& builder;
}; };

View File

@ -44,21 +44,21 @@ static MemRefType convertTensorToMemRef(TensorType type) {
} }
/// Insert an allocation and deallocation for the given MemRefType. /// Insert an allocation and deallocation for the given MemRefType.
static Value *insertAllocAndDealloc(MemRefType type, Location loc, static Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter, PatternRewriter &rewriter,
bool insertDealloc, bool insertDealloc,
ArrayRef<Value *> operands = {}) { ArrayRef<Value> operands = {}) {
// Put together alloc operands for any dynamic dimensions of the memref. // Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc; AllocOp alloc;
if (!operands.empty()) { if (!operands.empty()) {
auto memRefShape = type.getShape(); auto memRefShape = type.getShape();
auto rank = memRefShape.size(); auto rank = memRefShape.size();
std::map<int, Value *> fromOperands; std::map<int, Value> fromOperands;
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
int memRefDimIdx = rank - 1 - reversedIdx; int memRefDimIdx = rank - 1 - reversedIdx;
if (memRefShape[memRefDimIdx] < 0) { // unknown dimension if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
Value *maxDim = nullptr; Value maxDim = nullptr;
for (int i = 0; i < operands.size(); i++) { for (int i = 0; i < operands.size(); i++) {
auto operandShape = auto operandShape =
operands[i]->getType().cast<MemRefType>().getShape(); operands[i]->getType().cast<MemRefType>().getShape();
@ -85,7 +85,7 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc,
} }
} }
SmallVector<Value *, 4> allocOperands; SmallVector<Value, 4> allocOperands;
for (int i = 0; i < rank; ++i) for (int i = 0; i < rank; ++i)
if (memRefShape[i] < 0) if (memRefShape[i] < 0)
allocOperands.push_back(fromOperands[i]); allocOperands.push_back(fromOperands[i]);
@ -146,14 +146,14 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
// Get run-time dimension information for unknown dimensions used for // Get run-time dimension information for unknown dimensions used for
// broadcasting. // broadcasting.
std::map<int, std::map<int, Value *>> std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
MemRefType memRefType, ArrayRef<Value *> operands) { MemRefType memRefType, ArrayRef<Value> operands) {
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size(); int64_t rank = memRefShape.size();
// For unknown dimensions, we need to get dimension values at runtime in // For unknown dimensions, we need to get dimension values at runtime in
// order to do broadcasting. // order to do broadcasting.
std::map<int, std::map<int, Value *>> DimInfo; std::map<int, std::map<int, Value>> DimInfo;
// For each result dimension, compute the number of sharing operands. // For each result dimension, compute the number of sharing operands.
// Sharing operands are operands sharing the same index (counting from the // Sharing operands are operands sharing the same index (counting from the
// rightmost to the leftmost) for a given dimension. // rightmost to the leftmost) for a given dimension.
@ -173,7 +173,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
// We only care about unknown dimensions whose number of sharing operands is // We only care about unknown dimensions whose number of sharing operands is
// more than one, since they are potentially broadcasted dimensions. // more than one, since they are potentially broadcasted dimensions.
for (int i = 0; i < operands.size(); ++i) { for (int i = 0; i < operands.size(); ++i) {
std::map<int, Value *> broadcastedDims; std::map<int, Value> broadcastedDims;
auto shape = operands[i]->getType().cast<MemRefType>().getShape(); auto shape = operands[i]->getType().cast<MemRefType>().getShape();
int size = shape.size(); int size = shape.size();
for (int j = 0; j < shape.size(); ++j) { for (int j = 0; j < shape.size(); ++j) {
@ -192,17 +192,17 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
// Extract induction variables that are used for broadcasting values of a // Extract induction variables that are used for broadcasting values of a
// given operand. // given operand.
std::vector<Value *> std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value *> loopIVs, Value *operand, ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value *> broadcastedDims) { std::map<int, Value> broadcastedDims) {
// `operand` must has a ranked type. This should have been checked by the // `operand` must has a ranked type. This should have been checked by the
// shape inference pass. // shape inference pass.
auto operandShape = operand->getType().cast<MemRefType>().getShape(); auto operandShape = operand->getType().cast<MemRefType>().getShape();
auto rank = operandShape.size(); auto rank = operandShape.size();
auto loopCount = loopIVs.size(); auto loopCount = loopIVs.size();
std::vector<Value *> newLoopIVs; std::vector<Value> newLoopIVs;
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
auto dimIdx = rank - 1 - reversedIdx; auto dimIdx = rank - 1 - reversedIdx;
auto loopIdx = loopCount - 1 - reversedIdx; auto loopIdx = loopCount - 1 - reversedIdx;
@ -247,7 +247,7 @@ struct ScalarOp<ONNXMulOp> {
template <> template <>
struct ScalarOp<ONNXDivOp> { struct ScalarOp<ONNXDivOp> {
using FOp = DivFOp; using FOp = DivFOp;
using IOp = DivISOp; using IOp = SignedDivIOp;
}; };
template <> template <>
@ -295,9 +295,9 @@ using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
// Scalar unary ops for lowering to Krnl dialect. // Scalar unary ops for lowering to Krnl dialect.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <typename UnaryOp> template <typename UnaryOp>
Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types, Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) { ConversionPatternRewriter &rewriter) {
/* Lower UnaryOp to Ops in the Standard dialect. /* Lower UnaryOp to Ops in the Standard dialect.
*/ */
auto loc = op->getLoc(); auto loc = op->getLoc();
@ -318,14 +318,13 @@ Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXTanhOp // Scalar unary ops for lowering ONNXTanhOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op, Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Type> result_types, ArrayRef<Value> operands,
ArrayRef<Value *> operands, ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter &rewriter) {
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X)))) // AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *operand = operands[0]; Value operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto neg = rewriter.create<SubFOp>(loc, zero, operand); auto neg = rewriter.create<SubFOp>(loc, zero, operand);
@ -342,14 +341,13 @@ Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op,
// Scalar unary ops for lowering ONNXSinhOp // Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op, Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Type> result_types, ArrayRef<Value> operands,
ArrayRef<Value *> operands, ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter &rewriter) {
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2) // ConstantOp 2)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *operand = operands[0]; Value operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
@ -366,14 +364,13 @@ Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
// Scalar unary ops for lowering ONNXCoshOp // Scalar unary ops for lowering ONNXCoshOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op, Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Type> result_types, ArrayRef<Value> operands,
ArrayRef<Value *> operands, ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter &rewriter) {
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2) // ConstantOp 2)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *operand = operands[0]; Value operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
@ -390,14 +387,14 @@ Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
// Scalar unary ops for lowering ONNXSigmoidOp // Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op, Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value *> operands, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) { ConversionPatternRewriter &rewriter) {
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1, // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *operand = operands[0]; Value operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
@ -413,8 +410,8 @@ Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
// Scalar unary ops for lowering ONNXHardSigmoidOp // Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value *mapToLowerScalarOp<ONNXHardSigmoidOp>( Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value *> operands, Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) { ConversionPatternRewriter &rewriter) {
// %Y = AddFOp(MulFOp(alpha, %X), beta) // %Y = AddFOp(MulFOp(alpha, %X), beta)
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0), // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
@ -424,7 +421,7 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
// %Z, // %Z,
// Constant 1) // Constant 1)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *operand = operands[0]; Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta"); auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
@ -449,14 +446,14 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
// Scalar unary ops for lowering ONNXEluOp // Scalar unary ops for lowering ONNXEluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types, Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) { ConversionPatternRewriter &rewriter) {
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)), // MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
// %X) // %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *operand = operands[0]; Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
@ -478,15 +475,14 @@ Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXReluOp // Scalar unary ops for lowering ONNXReluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op, Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Type> result_types, ArrayRef<Value> operands,
ArrayRef<Value *> operands, ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter &rewriter) {
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// ConstantOp 0, // ConstantOp 0,
// %X) // %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *operand = operands[0]; Value operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto lessThanZero = auto lessThanZero =
@ -500,15 +496,15 @@ Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op,
// Scalar unary ops for lowering ONNXLeakyReluOp // Scalar unary ops for lowering ONNXLeakyReluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value * Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value *> operands, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) { ConversionPatternRewriter &rewriter) {
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, %X), // MulFOp(alpha, %X),
// %X) // %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *operand = operands[0]; Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
@ -525,17 +521,16 @@ mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXSeluOp // Scalar unary ops for lowering ONNXSeluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op, Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Type> result_types, ArrayRef<Value> operands,
ArrayRef<Value *> operands, ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter &rewriter) {
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0), // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
// MulFOp(gamma, %X), // MulFOp(gamma, %X),
// MulFOp(gamma, // MulFOp(gamma,
// SubFOp(MulFOp(alpha, ExpOp(%X)), // SubFOp(MulFOp(alpha, ExpOp(%X)),
// alpha))) // alpha)))
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *operand = operands[0]; Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma"); auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
@ -558,13 +553,12 @@ Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
// Scalar unary ops for lowering ONNXReciprocalOp // Scalar unary ops for lowering ONNXReciprocalOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value * Value mapToLowerScalarOp<ONNXReciprocalOp>(
mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types, Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
ArrayRef<Value *> operands, ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter &rewriter) {
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *operand = operands[0]; Value operand = operands[0];
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto result = rewriter.create<DivFOp>(loc, one, operand); auto result = rewriter.create<DivFOp>(loc, one, operand);
@ -576,15 +570,15 @@ mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXMaxOp // Scalar unary ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types, Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) { ConversionPatternRewriter &rewriter) {
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y), // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
// %X, // %X,
// %Y) // %Y)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *lhs = operands[0]; Value lhs = operands[0];
Value *rhs = operands[1]; Value rhs = operands[1];
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs); auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs); auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
return result; return result;
@ -594,15 +588,15 @@ Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXMinOp // Scalar unary ops for lowering ONNXMinOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value *mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types, Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) { ConversionPatternRewriter &rewriter) {
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y), // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
// %X, // %X,
// %Y) // %Y)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value *lhs = operands[0]; Value lhs = operands[0];
Value *rhs = operands[1]; Value rhs = operands[1];
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs); auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs); auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
return result; return result;
@ -615,7 +609,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid. // TODO: Check that the types are valid.
// An element-wise unary operation must have all operands and the result of // An element-wise unary operation must have all operands and the result of
@ -632,7 +626,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// dimensions with the result at this pre-optimization phase. // dimensions with the result at this pre-optimization phase.
// TODO: verify that dimensions match. // TODO: verify that dimensions match.
// TODO: can the dimension of the result differ after optimizations? // TODO: can the dimension of the result differ after optimizations?
Value *alloc; Value alloc;
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
@ -647,7 +641,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// Define loops. // Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank); auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
std::vector<Value *> originalLoops; std::vector<Value> originalLoops;
originalLoops.reserve(rank); originalLoops.reserve(rank);
for (auto result : loopsOp.getResults()) { for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result); originalLoops.push_back(result);
@ -655,7 +649,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// Define loop optimization. // Define loop optimization.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank); auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
std::vector<Value *> optimizedLoops; std::vector<Value> optimizedLoops;
optimizedLoops.reserve(rank); optimizedLoops.reserve(rank);
for (auto result : optimizedLoopsOp.getResults()) { for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result); optimizedLoops.push_back(result);
@ -695,7 +689,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
rewriter.setInsertionPointToStart(&iterationBlock); rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operation: // Handle the operation:
SmallVector<Value *, 4> loopIVs; SmallVector<Value, 4> loopIVs;
for (auto arg : iterationBlock.getArguments()) for (auto arg : iterationBlock.getArguments())
loopIVs.push_back(arg); loopIVs.push_back(arg);
@ -718,7 +712,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
PatternMatchResult PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid. // TODO: Check that the types are valid.
// An element-wise variadic operation must have all operands and the result // An element-wise variadic operation must have all operands and the result
@ -730,7 +724,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType); auto memRefType = convertTensorToMemRef(tensorType);
Value *alloc; Value alloc;
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);
// If the output has a dynamic dimension, we compute its dimension at // If the output has a dynamic dimension, we compute its dimension at
// runtime by using dimensions from the operands. // runtime by using dimensions from the operands.
@ -749,7 +743,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Define loops. // Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank); auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
std::vector<Value *> originalLoops; std::vector<Value> originalLoops;
originalLoops.reserve(rank); originalLoops.reserve(rank);
for (auto result : loopsOp.getResults()) { for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result); originalLoops.push_back(result);
@ -757,7 +751,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Define loop optimization. // Define loop optimization.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank); auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
std::vector<Value *> optimizedLoops; std::vector<Value> optimizedLoops;
optimizedLoops.reserve(rank); optimizedLoops.reserve(rank);
for (auto result : optimizedLoopsOp.getResults()) { for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result); optimizedLoops.push_back(result);
@ -781,7 +775,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Get run-time dimension information for unknown dimensions used for // Get run-time dimension information for unknown dimensions used for
// broadcasting. // broadcasting.
std::map<int, std::map<int, Value *>> broadcastedDimInfo = std::map<int, std::map<int, Value>> broadcastedDimInfo =
getBroadcastedDimInfo(loc, rewriter, memRefType, operands); getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack); auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
@ -801,12 +795,12 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
rewriter.setInsertionPointToStart(&iterationBlock); rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operation: // Handle the operation:
SmallVector<Value *, 4> loopIVs; SmallVector<Value, 4> loopIVs;
for (auto arg : iterationBlock.getArguments()) for (auto arg : iterationBlock.getArguments())
loopIVs.push_back(arg); loopIVs.push_back(arg);
// Fold over operands for each of their scalar values // Fold over operands for each of their scalar values
Value *accumulated, *next; Value accumulated, next;
auto accumulatedLoopIVs = getLoopIVsForBroadcasting( auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]); loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs); accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
@ -831,17 +825,17 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
PatternMatchResult PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>(); auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc(); auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType); auto memRefType = convertTensorToMemRef(tensorType);
Value *alloc; Value alloc;
// Compute size in bytes. // Compute size in bytes.
Value *tensorSize = rewriter.create<ConstantOp>( Value tensorSize = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
getMemRefEltSizeInBytes(memRefType))); getMemRefEltSizeInBytes(memRefType)));
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);
@ -849,14 +843,14 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
} else { } else {
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
SmallVector<Value *, 4> allocOperands; SmallVector<Value, 4> allocOperands;
for (int i = 0; i < memRefShape.size(); ++i) { for (int i = 0; i < memRefShape.size(); ++i) {
// The shape array can always be used to construct shape information of // The shape array can always be used to construct shape information of
// the result. // the result.
Value *index = rewriter.create<ConstantOp>( Value index = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
Value *loadedVal = rewriter.create<LoadOp>(loc, operands[1], index); Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
Value *int64LoadedVal = rewriter.create<ZeroExtendIOp>( Value int64LoadedVal = rewriter.create<ZeroExtendIOp>(
loc, loadedVal, rewriter.getIntegerType(64)); loc, loadedVal, rewriter.getIntegerType(64));
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal); tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
allocOperands.push_back(rewriter.create<IndexCastOp>( allocOperands.push_back(rewriter.create<IndexCastOp>(

View File

@ -30,7 +30,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
operandItr++; operandItr++;
// Organize operands into lower/upper bounds in affine.for ready formats. // Organize operands into lower/upper bounds in affine.for ready formats.
SmallVector<Value *, 4> lbOperands, ubOperands; SmallVector<Value, 4> lbOperands, ubOperands;
AffineMap lbMap, ubMap; AffineMap lbMap, ubMap;
for (int boundType = 0; boundType < 2; boundType++) { for (int boundType = 0; boundType < 2; boundType++) {
auto &operands = boundType == 0 ? lbOperands : ubOperands; auto &operands = boundType == 0 ? lbOperands : ubOperands;

View File

@ -51,7 +51,7 @@ public:
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {} : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
PatternMatchResult PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext(); auto *context = op->getContext();
auto loc = op->getLoc(); auto loc = op->getLoc();
@ -66,27 +66,27 @@ public:
// First operand. // First operand.
Type dstType = Type dstType =
operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1); operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
Value *alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>( Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1)); loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
Value *alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>( Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory); loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
// Second operand. // Second operand.
Type srcType = Type srcType =
operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1); operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
Value *alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>( Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1)); loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
Value *alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>( Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory); loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
// Size. // Size.
Value *int64Size = rewriter.create<LLVM::SExtOp>( Value int64Size = rewriter.create<LLVM::SExtOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]); loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
// Memcpy call // Memcpy call
rewriter.create<CallOp>( rewriter.create<CallOp>(
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value *>( ArrayRef<Value>(
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size})); {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
rewriter.eraseOp(op); rewriter.eraseOp(op);
@ -210,7 +210,7 @@ public:
// Retrieve dynamic mem refs from wrapped input, and convert every one of // Retrieve dynamic mem refs from wrapped input, and convert every one of
// them to static mem refs. // them to static mem refs.
SmallVector<Value *, 4> staticInputs; SmallVector<Value, 4> staticInputs;
auto wrappedInput = entryPointEntryBlock.getArgument(0); auto wrappedInput = entryPointEntryBlock.getArgument(0);
for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) { for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
// Call API function to retrieve the i-th dynamic memref. // Call API function to retrieve the i-th dynamic memref.
@ -225,13 +225,12 @@ public:
auto memRefTy = memRefPtrTy.getPointerElementTy(); auto memRefTy = memRefPtrTy.getPointerElementTy();
auto one = rewriter.create<LLVM::ConstantOp>( auto one = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(1)); loc, int32Ty, rewriter.getI32IntegerAttr(1));
Value *ptrToMemRef = Value ptrToMemRef = rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one, /*alignment=*/0);
/*alignment=*/0);
// Fill in the memref underlying ptrToMemRef with information extracted // Fill in the memref underlying ptrToMemRef with information extracted
// from dynMemRef. // from dynMemRef.
fillPtrToMemRefWithDynMemRef(*dynMemRef, *ptrToMemRef, rewriter, loc, fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc,
apiRegistry, llvmDialect); apiRegistry, llvmDialect);
// ptrToMemRef will be an input to main computation graph function. // ptrToMemRef will be an input to main computation graph function.
@ -261,8 +260,8 @@ public:
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank)); loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
auto outDynMemRef = callApi(rewriter, loc, apiRegistry, auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
API::CREATE_DYN_MEM_REF, {outMemRefRankVal}); API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
fillDynMemRefWithMemRef(*outMemRef, *outDynMemRef, rewriter, loc, fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry,
apiRegistry, llvmDialect); llvmDialect);
auto zero = rewriter.create<LLVM::ConstantOp>( auto zero = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(0)); loc, int32Ty, rewriter.getI32IntegerAttr(0));
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF, callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
@ -270,7 +269,7 @@ public:
// Return wrapped output. // Return wrapped output.
rewriter.create<LLVM::ReturnOp>(loc, rewriter.create<LLVM::ReturnOp>(loc,
SmallVector<Value *, 1>({wrappedOutput})); SmallVector<Value, 1>({wrappedOutput}));
return matchSuccess(); return matchSuccess();
} }
@ -315,11 +314,11 @@ private:
// Call a registered API, return the return SSA values if only one result is // Call a registered API, return the return SSA values if only one result is
// returned, otherwise return nullptr. // returned, otherwise return nullptr.
Value *callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry, Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
API apiId, ArrayRef<Value *> params) const { API apiId, ArrayRef<Value> params) const {
auto returnVals = rewriter.create<LLVM::CallOp>( auto returnVals = rewriter.create<LLVM::CallOp>(
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef, loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
ArrayRef<Value *>(params)); ArrayRef<Value>(params));
if (returnVals.getNumResults() == 1) if (returnVals.getNumResults() == 1)
return returnVals.getResult(0); return returnVals.getResult(0);
return nullptr; return nullptr;
@ -348,12 +347,11 @@ private:
auto memRefTy = memRefPtrTy.getPointerElementTy(); auto memRefTy = memRefPtrTy.getPointerElementTy();
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
Value *memRef = Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, ptrToMemRef);
rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, &ptrToMemRef);
// Set dataPtr and alignedDataPtr; // Set dataPtr and alignedDataPtr;
auto dataPtr = auto dataPtr =
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {&dynMemRef}); callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
dataPtr = rewriter.create<LLVM::BitcastOp>( dataPtr = rewriter.create<LLVM::BitcastOp>(
loc, memRefTy.getStructElementType(0), dataPtr); loc, memRefTy.getStructElementType(0), dataPtr);
memRef = rewriter.create<LLVM::InsertValueOp>( memRef = rewriter.create<LLVM::InsertValueOp>(
@ -373,9 +371,9 @@ private:
// Get rank, sizes array ptr and strides array ptr. // Get rank, sizes array ptr and strides array ptr.
auto rank = memRefTy.getStructElementType(3).getArrayNumElements(); auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
auto sizesArrayPtr = auto sizesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&dynMemRef}); callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
auto stridesArrayPtr = auto stridesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&dynMemRef}); callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {dynMemRef});
for (decltype(rank) i = 0; i < rank; i++) { for (decltype(rank) i = 0; i < rank; i++) {
auto dimIdx = rewriter.create<LLVM::ConstantOp>( auto dimIdx = rewriter.create<LLVM::ConstantOp>(
@ -384,7 +382,7 @@ private:
// Insert size of the dimension. // Insert size of the dimension.
auto dimSizePtr = rewriter.create<LLVM::GEPOp>( auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr, loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value *>({dimIdx})); ArrayRef<Value>({dimIdx}));
auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(), auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(),
dimSizePtr); dimSizePtr);
memRef = rewriter.create<LLVM::InsertValueOp>( memRef = rewriter.create<LLVM::InsertValueOp>(
@ -395,7 +393,7 @@ private:
// Insert stride of the dimension. // Insert stride of the dimension.
auto dimStridePtr = rewriter.create<LLVM::GEPOp>( auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr, loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value *>({dimIdx})); ArrayRef<Value>({dimIdx}));
auto dimStride = rewriter.create<LLVM::LoadOp>( auto dimStride = rewriter.create<LLVM::LoadOp>(
loc, int64Ty.getPointerTo(), dimStridePtr); loc, int64Ty.getPointerTo(), dimStridePtr);
memRef = rewriter.create<LLVM::InsertValueOp>( memRef = rewriter.create<LLVM::InsertValueOp>(
@ -404,7 +402,7 @@ private:
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
} }
rewriter.create<LLVM::StoreOp>(loc, memRef, &ptrToMemRef); rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
} }
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef, void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
@ -415,19 +413,19 @@ private:
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
// Extract the data pointer, and record it in dynamic mem ref created. // Extract the data pointer, and record it in dynamic mem ref created.
Value *outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>( Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
loc, outMemRefTy.getStructElementType(0), &outMemRef, loc, outMemRefTy.getStructElementType(0), outMemRef,
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)})); rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>( outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr); loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
callApi(rewriter, loc, apiRegistry, API::SET_DATA, callApi(rewriter, loc, apiRegistry, API::SET_DATA,
{&outDynMemRef, outMemRefDataPtr}); {outDynMemRef, outMemRefDataPtr});
auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements(); auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
auto sizesArrayPtr = auto sizesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&outDynMemRef}); callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
auto stridesArrayPtr = auto stridesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&outDynMemRef}); callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outDynMemRef});
for (decltype(rank) i = 0; i < rank; i++) { for (decltype(rank) i = 0; i < rank; i++) {
auto dimIdx = rewriter.create<LLVM::ConstantOp>( auto dimIdx = rewriter.create<LLVM::ConstantOp>(
@ -435,22 +433,22 @@ private:
// Transfer size of dimension from memref to dynamic memref. // Transfer size of dimension from memref to dynamic memref.
auto dimSize = rewriter.create<LLVM::ExtractValueOp>( auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
loc, int64Ty, &outMemRef, loc, int64Ty, outMemRef,
rewriter.getArrayAttr( rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)})); {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
auto dimSizePtr = rewriter.create<LLVM::GEPOp>( auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr, loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value *>({dimIdx})); ArrayRef<Value>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr); rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
// Transfer stride of dimension from memref to dynamic memref. // Transfer stride of dimension from memref to dynamic memref.
auto dimStride = rewriter.create<LLVM::ExtractValueOp>( auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
loc, int64Ty, &outMemRef, loc, int64Ty, outMemRef,
rewriter.getArrayAttr( rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
auto dimStridePtr = rewriter.create<LLVM::GEPOp>( auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), stridesArrayPtr, loc, int64Ty.getPointerTo(), stridesArrayPtr,
ArrayRef<Value *>({dimIdx})); ArrayRef<Value>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr); rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
} }
} }

View File

@ -1,5 +1,5 @@
set(LLVM_LIT ${LLVM_SRC}/utils/lit/lit.py) set(LLVM_LIT ${LLVM_PROJ_SRC}/utils/lit/lit.py)
set(LLVM_DEFAULT_EXTERNAL_LIT ${LLVM_BUILD}/bin/llvm-lit) set(LLVM_DEFAULT_EXTERNAL_LIT ${LLVM_PROJ_BUILD}/bin/llvm-lit)
configure_lit_site_cfg(${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in configure_lit_site_cfg(${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py

View File

@ -2,7 +2,7 @@
import lit.llvm import lit.llvm
config.llvm_tools_dir = "@MLIR_TOOLS_DIR@" config.llvm_tools_dir = "@MLIR_TOOLS_DIR@"
config.mlir_obj_root = "@LLVM_BUILD@" config.mlir_obj_root = "@LLVM_PROJ_BUILD@"
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@" config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
config.suffixes = ['.mlir'] config.suffixes = ['.mlir']