Merge pull request #9 from clang-ykt/use-mlir-in-llvm-project
Update to latest MLIR
This commit is contained in:
commit
322002f509
|
@ -7,52 +7,33 @@ jobs:
|
|||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: "Pull Submodules"
|
||||
name: Pull Submodules
|
||||
command: |
|
||||
git submodule update --init --recursive
|
||||
- run:
|
||||
name: Check current directory
|
||||
command: pwd
|
||||
- run:
|
||||
name: Check current directory content
|
||||
command: ls
|
||||
- run:
|
||||
name: Installing GCC
|
||||
command: 'sudo apt-get update && sudo apt-get install -y gcc g++'
|
||||
- run:
|
||||
name: Install CMAKE
|
||||
command: 'sudo apt-get update && sudo apt-get install -y cmake ninja-build'
|
||||
- run:
|
||||
name: Install Protobuf
|
||||
command: 'sudo apt-get update && sudo apt-get install -y protobuf-compiler'
|
||||
- run:
|
||||
name: Check gcc version
|
||||
command: gcc --version
|
||||
|
||||
name: Installing GCC, CMake, Ninja, Protobuf
|
||||
command: sudo apt-get update && sudo apt-get install -y gcc g++ cmake ninja-build protobuf-compiler
|
||||
# Use cached mlir installation if possible.
|
||||
- restore_cache:
|
||||
key: ONNF-MLIR-{{ arch }}
|
||||
key: V2-LLVM-PROJECT-{{ arch }}
|
||||
- run:
|
||||
name: Install MLIR
|
||||
command: |
|
||||
# Check whether cache restoration succeeds by checking whether
|
||||
# mlir-opt executable exists.
|
||||
if [ ! -f llvm-project/build/bin/mlir-opt ]; then
|
||||
git clone https://github.com/llvm/llvm-project.git
|
||||
cd llvm-project && git checkout 9b6ad8466bb8b97082b705270603ad7f4559e931 && cd ..
|
||||
git clone https://github.com/tensorflow/mlir llvm-project/llvm/projects/mlir
|
||||
cd llvm-project/llvm/projects/mlir && git checkout 0710266d0f56cf6ab0f437badbd7416b6cecdf5f && cd ../../../..
|
||||
mkdir llvm-project/build
|
||||
cd llvm-project/build
|
||||
cmake -G Ninja ../llvm -DLLVM_ENABLE_RTTI=ON -DLLVM_BUILD_EXAMPLES=OFF -DLLVM_TARGETS_TO_BUILD="host" -DCMAKE_BUILD_TYPE=Release
|
||||
CMAKE_EXE_LINKER_FLAGS="-Wl,--reduce-memory-overheads -Wl,--hash-size=512" cmake --build . --target check-mlir -- -j 4
|
||||
export MAKEFLAGS=-j4
|
||||
source .circleci/install-mlir.sh
|
||||
fi
|
||||
- save_cache:
|
||||
key: ONNF-MLIR-{{ arch }}
|
||||
key: V2-LLVM-PROJECT-{{ arch }}
|
||||
paths:
|
||||
- llvm-project
|
||||
- run:
|
||||
name: Install ONNF
|
||||
command: |
|
||||
mkdir build && cd build
|
||||
LLVM_SRC=$(pwd)/../llvm-project/llvm LLVM_BUILD=$(pwd)/../llvm-project/build cmake ..
|
||||
LLVM_PROJ_SRC=$(pwd)/../llvm-project/ LLVM_PROJ_BUILD=$(pwd)/../llvm-project/build cmake ..
|
||||
make all
|
||||
LIT_OPTS=-v make check-mlir-lit
|
||||
- run:
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
git clone https://github.com/llvm/llvm-project.git
|
||||
mkdir llvm-project/build
|
||||
cd llvm-project/build
|
||||
cmake -G Ninja ../llvm \
|
||||
-DLLVM_ENABLE_PROJECTS=mlir \
|
||||
-DLLVM_BUILD_EXAMPLES=ON \
|
||||
-DLLVM_TARGETS_TO_BUILD="host" \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DLLVM_ENABLE_RTTI=ON
|
||||
|
||||
cmake --build . --target check-mlir -- ${MAKEFLAGS}
|
46
MLIR.cmake
46
MLIR.cmake
|
@ -1,38 +1,38 @@
|
|||
# Path to LLVM source folder.
|
||||
if(DEFINED ENV{LLVM_SRC})
|
||||
set(LLVM_SRC $ENV{LLVM_SRC})
|
||||
if(EXISTS ${LLVM_SRC})
|
||||
message(STATUS "LLVM_SRC " ${LLVM_SRC})
|
||||
if(DEFINED ENV{LLVM_PROJ_SRC})
|
||||
set(LLVM_PROJ_SRC $ENV{LLVM_PROJ_SRC})
|
||||
if(EXISTS ${LLVM_PROJ_SRC})
|
||||
message(STATUS "LLVM_PROJ_SRC " ${LLVM_PROJ_SRC})
|
||||
else()
|
||||
message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: "
|
||||
${LLVM_SRC})
|
||||
message(FATAL_ERROR "The path specified by LLVM_PROJ_SRC does not exist: "
|
||||
${LLVM_PROJ_SRC})
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "env variable LLVM_SRC not set")
|
||||
message(FATAL_ERROR "env variable LLVM_PROJ_SRC not set")
|
||||
endif()
|
||||
|
||||
# Path to LLVM build folder
|
||||
if(DEFINED ENV{LLVM_BUILD})
|
||||
set(LLVM_BUILD $ENV{LLVM_BUILD})
|
||||
if(EXISTS ${LLVM_BUILD})
|
||||
message(STATUS "LLVM_BUILD " ${LLVM_BUILD})
|
||||
if(DEFINED ENV{LLVM_PROJ_BUILD})
|
||||
set(LLVM_PROJ_BUILD $ENV{LLVM_PROJ_BUILD})
|
||||
if(EXISTS ${LLVM_PROJ_BUILD})
|
||||
message(STATUS "LLVM_PROJ_BUILD " ${LLVM_PROJ_BUILD})
|
||||
else()
|
||||
message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: "
|
||||
${LLVM_BUILD})
|
||||
message(FATAL_ERROR "The path specified by LLVM_PROJ_BUILD does not exist: "
|
||||
${LLVM_PROJ_BUILD})
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "env variable LLVM_BUILD not set")
|
||||
message(FATAL_ERROR "env variable LLVM_PROJ_BUILD not set")
|
||||
endif()
|
||||
|
||||
# LLVM project lib folder
|
||||
set(LLVM_PROJECT_LIB ${LLVM_BUILD}/lib)
|
||||
set(LLVM_PROJECT_LIB ${LLVM_PROJ_BUILD}/lib)
|
||||
|
||||
# Include paths for MLIR
|
||||
set(LLVM_SRC_INCLUDE_PATH ${LLVM_SRC}/include)
|
||||
set(LLVM_BIN_INCLUDE_PATH ${LLVM_BUILD}/include)
|
||||
set(MLIR_SRC_INCLUDE_PATH ${LLVM_SRC}/projects/mlir/include)
|
||||
set(MLIR_BIN_INCLUDE_PATH ${LLVM_BUILD}/projects/mlir/include)
|
||||
set(MLIR_TOOLS_DIR ${LLVM_BUILD}/bin)
|
||||
set(LLVM_SRC_INCLUDE_PATH ${LLVM_PROJ_SRC}/llvm/include)
|
||||
set(LLVM_BIN_INCLUDE_PATH ${LLVM_PROJ_BUILD}/include)
|
||||
set(MLIR_SRC_INCLUDE_PATH ${LLVM_PROJ_SRC}/mlir/include)
|
||||
set(MLIR_BIN_INCLUDE_PATH ${LLVM_PROJ_BUILD}/tools/mlir/include)
|
||||
set(MLIR_TOOLS_DIR ${LLVM_PROJ_BUILD}/bin)
|
||||
|
||||
set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/bin)
|
||||
set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
|
||||
|
@ -173,7 +173,7 @@ function(whole_archive_link target lib_dir)
|
|||
endfunction(whole_archive_link)
|
||||
|
||||
function(whole_archive_link_mlir target)
|
||||
whole_archive_link(${target} ${LLVM_BUILD}/lib ${ARGN})
|
||||
whole_archive_link(${target} ${LLVM_PROJ_BUILD}/lib ${ARGN})
|
||||
endfunction(whole_archive_link_mlir)
|
||||
|
||||
function(whole_archive_link_onnf target)
|
||||
|
@ -184,7 +184,7 @@ function(whole_archive_link_onnf target)
|
|||
endfunction(whole_archive_link_onnf)
|
||||
|
||||
set(LLVM_CMAKE_DIR
|
||||
"${LLVM_BUILD}/lib/cmake/llvm"
|
||||
"${LLVM_PROJ_BUILD}/lib/cmake/llvm"
|
||||
CACHE PATH "Path to LLVM cmake modules")
|
||||
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
|
||||
include(AddLLVM)
|
||||
|
@ -205,5 +205,5 @@ endfunction()
|
|||
# table gen utility itself can be detected and cause re-compilation of .td file.
|
||||
add_executable(mlir-tblgen IMPORTED)
|
||||
set_property(TARGET mlir-tblgen
|
||||
PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen)
|
||||
PROPERTY IMPORTED_LOCATION ${LLVM_PROJ_BUILD}/bin/mlir-tblgen)
|
||||
set(MLIR_TABLEGEN_EXE mlir-tblgen)
|
||||
|
|
|
@ -71,7 +71,7 @@ struct OnnxOnnfSymbolMapping {
|
|||
* @param name onnx tensor name.
|
||||
* @return onnf tensor corresponding to `name`.
|
||||
*/
|
||||
mlir::Value *GetTensorByOnnxName(std::string name) {
|
||||
mlir::Value GetTensorByOnnxName(std::string name) {
|
||||
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
|
||||
onnx_name2onnf_tensor.end() &&
|
||||
"Tensor not found");
|
||||
|
@ -81,9 +81,9 @@ struct OnnxOnnfSymbolMapping {
|
|||
/*!
|
||||
* Add a new mapping from onnx tensor name to MLIR symbol.
|
||||
* @param name onnx tensor name.
|
||||
* @param tensor MLIR Value* pointer.
|
||||
* @param tensor MLIR Value pointer.
|
||||
*/
|
||||
void AddMapping(std::string name, mlir::Value *tensor) {
|
||||
void AddMapping(std::string name, mlir::Value tensor) {
|
||||
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
|
||||
"Tensor already exists.");
|
||||
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
|
||||
|
@ -97,7 +97,7 @@ private:
|
|||
/*!
|
||||
* mapping from onnx tensor names to MLIR tensor.
|
||||
*/
|
||||
std::map<std::string, mlir::Value*> onnx_name2onnf_tensor;
|
||||
std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
|
||||
};
|
||||
|
||||
class FrontendGenImpl {
|
||||
|
@ -192,13 +192,13 @@ private:
|
|||
|
||||
/*!
|
||||
* Import a input tensor symbol by recording a new entry in frontend_symbols_
|
||||
* recording the mapping between legalized onnx tensor name and mlir::Value*
|
||||
* recording the mapping between legalized onnx tensor name and mlir::Value
|
||||
* for further lookup in computation node importing.
|
||||
* @param input onnx input tensor ValueInfoProto.
|
||||
* @param symbol mlir input argument.
|
||||
*/
|
||||
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
|
||||
mlir::Value *symbol) {
|
||||
mlir::Value symbol) {
|
||||
auto input_tensor_legalized_name = legalize_name(input.name());
|
||||
assert(
|
||||
!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
|
||||
|
@ -480,7 +480,7 @@ private:
|
|||
}
|
||||
|
||||
void ImportNodeGeneric(onnx::NodeProto node) {
|
||||
std::vector<mlir::Value *> inputs;
|
||||
std::vector<mlir::Value> inputs;
|
||||
for (auto item : node.input()) {
|
||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||
|
@ -515,7 +515,7 @@ private:
|
|||
onnx::NodeProto node, int nIn, int nOut,
|
||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||
attrs) {
|
||||
std::vector<mlir::Value *> inputs;
|
||||
std::vector<mlir::Value> inputs;
|
||||
for (auto item : node.input()) {
|
||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||
|
@ -562,7 +562,7 @@ private:
|
|||
onnx::NodeProto node, int nIn, int nOut,
|
||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||
attrs) {
|
||||
std::vector<mlir::Value *> inputs;
|
||||
std::vector<mlir::Value> inputs;
|
||||
for (auto item : node.input()) {
|
||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||
|
@ -633,7 +633,7 @@ private:
|
|||
}
|
||||
|
||||
void ImportNode(onnx::NodeProto node) {
|
||||
std::vector<mlir::Value *> inputs;
|
||||
std::vector<mlir::Value> inputs;
|
||||
for (auto item : node.input()) {
|
||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||
|
@ -662,17 +662,17 @@ private:
|
|||
* Import output tensor, by doing the following:
|
||||
* - Add the type of this output tensor to a list of tensor
|
||||
* types representing return types of this graph function.
|
||||
* - Add this output tensor to the list of mlir::Value*
|
||||
* - Add this output tensor to the list of mlir::Value
|
||||
* to be returned by the function representing computation graph.
|
||||
* @param output onnx output tensor ValueInfoProto.
|
||||
* @param ret_types a vector of tensor types representing graph's
|
||||
* output tensor types.
|
||||
* @param ret_vals a vector of mlir Value* representing graph's
|
||||
* @param ret_vals a vector of mlir Value representing graph's
|
||||
* output tensor.
|
||||
*/
|
||||
void ImportOutputTensor(const onnx::ValueInfoProto &output,
|
||||
llvm::SmallVectorImpl<mlir::Type> &ret_types,
|
||||
llvm::SmallVectorImpl<mlir::Value *> &ret_vals) {
|
||||
llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
|
||||
auto output_tensor_legalized_name = legalize_name(output.name());
|
||||
assert(
|
||||
frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
|
||||
|
@ -722,7 +722,7 @@ private:
|
|||
}
|
||||
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
llvm::SmallVector<mlir::Value *, 4> ret_vals;
|
||||
llvm::SmallVector<mlir::Value, 4> ret_vals;
|
||||
// Import the output tensors
|
||||
for (const auto &output : graph.output()) {
|
||||
ImportOutputTensor(output, ret_types, ret_vals);
|
||||
|
|
|
@ -9,8 +9,9 @@ namespace onnf {
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||
const Type& operandType, Value*& operand) {
|
||||
ParseResult
|
||||
KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
|
||||
Value &operand) {
|
||||
// If operand queue is empty, parse more operands and cache them.
|
||||
if (_operandRefQueue.empty()) {
|
||||
// Parse operand types:
|
||||
|
@ -27,7 +28,7 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
|||
auto operand_ref = _operandRefQueue.front();
|
||||
_operandRefQueue.pop();
|
||||
|
||||
llvm::SmallVector<Value*, 1> operands;
|
||||
llvm::SmallVector<Value, 1> operands;
|
||||
_parser.resolveOperand(operand_ref, operandType, operands);
|
||||
operand = operands.front();
|
||||
return success();
|
||||
|
@ -38,8 +39,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
|||
}
|
||||
|
||||
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
|
||||
Value* operand = nullptr;
|
||||
const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) {
|
||||
Value operand = nullptr;
|
||||
if (ParseOptionalOperand(operandType, operand))
|
||||
return failure();
|
||||
|
||||
|
@ -47,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
|||
return success();
|
||||
}
|
||||
|
||||
ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||
const Type& operandType, Value*& operand) {
|
||||
ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType,
|
||||
Value &operand) {
|
||||
if (ParseOptionalOperand(operandType, operand))
|
||||
return _parser.emitError(
|
||||
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||
|
@ -56,7 +57,7 @@ ParseResult KrnlDialectOperandParser::ParseOperand(
|
|||
}
|
||||
|
||||
ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
|
||||
const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) {
|
||||
if (ParseOptionalOperand(operandType, operandList))
|
||||
return _parser.emitError(
|
||||
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||
|
@ -129,7 +130,7 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
|
|||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||
}
|
||||
|
||||
void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) {
|
||||
void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
|
||||
if (boundMaps.size() % 2 == 0)
|
||||
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
|
||||
AffineMap map = builder.getSymbolIdentityMap();
|
||||
|
|
|
@ -17,20 +17,22 @@ class KrnlDialectOperandParser {
|
|||
: _parser(parser), _builder(parser.getBuilder()){};
|
||||
|
||||
// Parse an optional operand.
|
||||
mlir::ParseResult ParseOptionalOperand(
|
||||
const mlir::Type& operandType, mlir::Value*& operand);
|
||||
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
|
||||
mlir::Value &operand);
|
||||
|
||||
// Parse an optional operand and push it to an operand list.
|
||||
mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType,
|
||||
llvm::SmallVectorImpl<mlir::Value*>& operandList);
|
||||
mlir::ParseResult
|
||||
ParseOptionalOperand(const mlir::Type &operandType,
|
||||
llvm::SmallVectorImpl<mlir::Value> &operandList);
|
||||
|
||||
// Parse a required operand.
|
||||
mlir::ParseResult ParseOperand(
|
||||
const mlir::Type& operandType, mlir::Value*& operand);
|
||||
mlir::ParseResult ParseOperand(const mlir::Type &operandType,
|
||||
mlir::Value &operand);
|
||||
|
||||
// Parse a required operand and push it to an operand list.
|
||||
mlir::ParseResult ParseOperand(const mlir::Type& operandType,
|
||||
llvm::SmallVectorImpl<mlir::Value*>& operandList);
|
||||
mlir::ParseResult
|
||||
ParseOperand(const mlir::Type &operandType,
|
||||
llvm::SmallVectorImpl<mlir::Value> &operandList);
|
||||
|
||||
// Do we have more operands to parse?
|
||||
bool hasOperandLeft() { return !_operandRefQueue.empty(); }
|
||||
|
@ -63,11 +65,10 @@ void printBound(mlir::AffineMapAttr boundMap,
|
|||
namespace mlir {
|
||||
|
||||
struct KrnlIterateOperandPack {
|
||||
KrnlIterateOperandPack(mlir::Builder& builder,
|
||||
llvm::ArrayRef<mlir::Value*> inputLoops,
|
||||
llvm::ArrayRef<mlir::Value*> optimizedLoops)
|
||||
: builder(builder),
|
||||
inputLoops(inputLoops),
|
||||
KrnlIterateOperandPack(mlir::Builder &builder,
|
||||
llvm::ArrayRef<mlir::Value> inputLoops,
|
||||
llvm::ArrayRef<mlir::Value> optimizedLoops)
|
||||
: builder(builder), inputLoops(inputLoops),
|
||||
optimizedLoops(optimizedLoops) {
|
||||
_operands.insert(
|
||||
_operands.end(), optimizedLoops.begin(), optimizedLoops.end());
|
||||
|
@ -75,9 +76,9 @@ struct KrnlIterateOperandPack {
|
|||
|
||||
void pushConstantBound(int64_t bound);
|
||||
|
||||
void pushOperandBound(mlir::Value* operand);
|
||||
void pushOperandBound(mlir::Value operand);
|
||||
|
||||
llvm::SmallVector<mlir::Value*, 8> getOperands() const { return _operands; }
|
||||
llvm::SmallVector<mlir::Value, 8> getOperands() const { return _operands; }
|
||||
|
||||
mlir::ArrayAttr getAttributes() const {
|
||||
return builder.getArrayAttr(boundMaps);
|
||||
|
@ -90,11 +91,11 @@ struct KrnlIterateOperandPack {
|
|||
private:
|
||||
int _boundIdx = 0;
|
||||
|
||||
llvm::SmallVector<mlir::Value*, 8> _operands;
|
||||
llvm::SmallVector<mlir::Value, 8> _operands;
|
||||
|
||||
llvm::SmallVector<mlir::Attribute, 8> boundMaps;
|
||||
|
||||
llvm::ArrayRef<mlir::Value*> inputLoops, optimizedLoops;
|
||||
llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;
|
||||
|
||||
mlir::Builder& builder;
|
||||
};
|
||||
|
|
|
@ -44,21 +44,21 @@ static MemRefType convertTensorToMemRef(TensorType type) {
|
|||
}
|
||||
|
||||
/// Insert an allocation and deallocation for the given MemRefType.
|
||||
static Value *insertAllocAndDealloc(MemRefType type, Location loc,
|
||||
PatternRewriter &rewriter,
|
||||
bool insertDealloc,
|
||||
ArrayRef<Value *> operands = {}) {
|
||||
static Value insertAllocAndDealloc(MemRefType type, Location loc,
|
||||
PatternRewriter &rewriter,
|
||||
bool insertDealloc,
|
||||
ArrayRef<Value> operands = {}) {
|
||||
// Put together alloc operands for any dynamic dimensions of the memref.
|
||||
AllocOp alloc;
|
||||
if (!operands.empty()) {
|
||||
auto memRefShape = type.getShape();
|
||||
auto rank = memRefShape.size();
|
||||
|
||||
std::map<int, Value *> fromOperands;
|
||||
std::map<int, Value> fromOperands;
|
||||
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
|
||||
int memRefDimIdx = rank - 1 - reversedIdx;
|
||||
if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
|
||||
Value *maxDim = nullptr;
|
||||
Value maxDim = nullptr;
|
||||
for (int i = 0; i < operands.size(); i++) {
|
||||
auto operandShape =
|
||||
operands[i]->getType().cast<MemRefType>().getShape();
|
||||
|
@ -85,7 +85,7 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
}
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> allocOperands;
|
||||
SmallVector<Value, 4> allocOperands;
|
||||
for (int i = 0; i < rank; ++i)
|
||||
if (memRefShape[i] < 0)
|
||||
allocOperands.push_back(fromOperands[i]);
|
||||
|
@ -146,14 +146,14 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
|
|||
|
||||
// Get run-time dimension information for unknown dimensions used for
|
||||
// broadcasting.
|
||||
std::map<int, std::map<int, Value *>>
|
||||
std::map<int, std::map<int, Value>>
|
||||
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
||||
MemRefType memRefType, ArrayRef<Value *> operands) {
|
||||
MemRefType memRefType, ArrayRef<Value> operands) {
|
||||
auto memRefShape = memRefType.getShape();
|
||||
int64_t rank = memRefShape.size();
|
||||
// For unknown dimensions, we need to get dimension values at runtime in
|
||||
// order to do broadcasting.
|
||||
std::map<int, std::map<int, Value *>> DimInfo;
|
||||
std::map<int, std::map<int, Value>> DimInfo;
|
||||
// For each result dimension, compute the number of sharing operands.
|
||||
// Sharing operands are operands sharing the same index (counting from the
|
||||
// rightmost to the leftmost) for a given dimension.
|
||||
|
@ -173,7 +173,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
|||
// We only care about unknown dimensions whose number of sharing operands is
|
||||
// more than one, since they are potentially broadcasted dimensions.
|
||||
for (int i = 0; i < operands.size(); ++i) {
|
||||
std::map<int, Value *> broadcastedDims;
|
||||
std::map<int, Value> broadcastedDims;
|
||||
auto shape = operands[i]->getType().cast<MemRefType>().getShape();
|
||||
int size = shape.size();
|
||||
for (int j = 0; j < shape.size(); ++j) {
|
||||
|
@ -192,17 +192,17 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
|||
|
||||
// Extract induction variables that are used for broadcasting values of a
|
||||
// given operand.
|
||||
std::vector<Value *>
|
||||
std::vector<Value>
|
||||
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
|
||||
ArrayRef<Value *> loopIVs, Value *operand,
|
||||
std::map<int, Value *> broadcastedDims) {
|
||||
ArrayRef<Value> loopIVs, Value operand,
|
||||
std::map<int, Value> broadcastedDims) {
|
||||
// `operand` must has a ranked type. This should have been checked by the
|
||||
// shape inference pass.
|
||||
auto operandShape = operand->getType().cast<MemRefType>().getShape();
|
||||
auto rank = operandShape.size();
|
||||
auto loopCount = loopIVs.size();
|
||||
|
||||
std::vector<Value *> newLoopIVs;
|
||||
std::vector<Value> newLoopIVs;
|
||||
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
|
||||
auto dimIdx = rank - 1 - reversedIdx;
|
||||
auto loopIdx = loopCount - 1 - reversedIdx;
|
||||
|
@ -247,7 +247,7 @@ struct ScalarOp<ONNXMulOp> {
|
|||
template <>
|
||||
struct ScalarOp<ONNXDivOp> {
|
||||
using FOp = DivFOp;
|
||||
using IOp = DivISOp;
|
||||
using IOp = SignedDivIOp;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
@ -295,9 +295,9 @@ using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
|
|||
// Scalar unary ops for lowering to Krnl dialect.
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <typename UnaryOp>
|
||||
Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
/* Lower UnaryOp to Ops in the Standard dialect.
|
||||
*/
|
||||
auto loc = op->getLoc();
|
||||
|
@ -318,14 +318,13 @@ Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
|
|||
// Scalar unary ops for lowering ONNXTanhOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
||||
auto loc = op->getLoc();
|
||||
Value *operand = operands[0];
|
||||
Value operand = operands[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||
|
@ -342,14 +341,13 @@ Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op,
|
|||
// Scalar unary ops for lowering ONNXSinhOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||
// ConstantOp 2)
|
||||
auto loc = op->getLoc();
|
||||
Value *operand = operands[0];
|
||||
Value operand = operands[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
||||
|
@ -366,14 +364,13 @@ Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
|
|||
// Scalar unary ops for lowering ONNXCoshOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||
// ConstantOp 2)
|
||||
auto loc = op->getLoc();
|
||||
Value *operand = operands[0];
|
||||
Value operand = operands[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
||||
|
@ -390,14 +387,14 @@ Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
|
|||
// Scalar unary ops for lowering ONNXSigmoidOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
|
||||
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
|
||||
auto loc = op->getLoc();
|
||||
Value *operand = operands[0];
|
||||
Value operand = operands[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
||||
|
@ -413,8 +410,8 @@ Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
|
|||
// Scalar unary ops for lowering ONNXHardSigmoidOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
||||
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value *> operands,
|
||||
Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
||||
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// %Y = AddFOp(MulFOp(alpha, %X), beta)
|
||||
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
|
||||
|
@ -424,7 +421,7 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
|||
// %Z,
|
||||
// Constant 1)
|
||||
auto loc = op->getLoc();
|
||||
Value *operand = operands[0];
|
||||
Value operand = operands[0];
|
||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
|
||||
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
|
||||
|
||||
|
@ -449,14 +446,14 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
|||
// Scalar unary ops for lowering ONNXEluOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
|
||||
// %X)
|
||||
auto loc = op->getLoc();
|
||||
Value *operand = operands[0];
|
||||
Value operand = operands[0];
|
||||
|
||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
|
||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||
|
@ -478,15 +475,14 @@ Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
// Scalar unary ops for lowering ONNXReluOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||
// ConstantOp 0,
|
||||
// %X)
|
||||
auto loc = op->getLoc();
|
||||
Value *operand = operands[0];
|
||||
Value operand = operands[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||
auto lessThanZero =
|
||||
|
@ -500,15 +496,15 @@ Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op,
|
|||
// Scalar unary ops for lowering ONNXLeakyReluOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *
|
||||
mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||
// MulFOp(alpha, %X),
|
||||
// %X)
|
||||
auto loc = op->getLoc();
|
||||
Value *operand = operands[0];
|
||||
Value operand = operands[0];
|
||||
|
||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
|
||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||
|
@ -525,17 +521,16 @@ mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
// Scalar unary ops for lowering ONNXSeluOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
|
||||
// MulFOp(gamma, %X),
|
||||
// MulFOp(gamma,
|
||||
// SubFOp(MulFOp(alpha, ExpOp(%X)),
|
||||
// alpha)))
|
||||
auto loc = op->getLoc();
|
||||
Value *operand = operands[0];
|
||||
Value operand = operands[0];
|
||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
|
||||
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
|
||||
|
||||
|
@ -558,13 +553,12 @@ Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
|
|||
// Scalar unary ops for lowering ONNXReciprocalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *
|
||||
mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXReciprocalOp>(
|
||||
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
|
||||
auto loc = op->getLoc();
|
||||
Value *operand = operands[0];
|
||||
Value operand = operands[0];
|
||||
|
||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
||||
auto result = rewriter.create<DivFOp>(loc, one, operand);
|
||||
|
@ -576,15 +570,15 @@ mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
// Scalar unary ops for lowering ONNXMaxOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
|
||||
// %X,
|
||||
// %Y)
|
||||
auto loc = op->getLoc();
|
||||
Value *lhs = operands[0];
|
||||
Value *rhs = operands[1];
|
||||
Value lhs = operands[0];
|
||||
Value rhs = operands[1];
|
||||
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
|
||||
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
|
||||
return result;
|
||||
|
@ -594,15 +588,15 @@ Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
// Scalar unary ops for lowering ONNXMinOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value *mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
|
||||
// %X,
|
||||
// %Y)
|
||||
auto loc = op->getLoc();
|
||||
Value *lhs = operands[0];
|
||||
Value *rhs = operands[1];
|
||||
Value lhs = operands[0];
|
||||
Value rhs = operands[1];
|
||||
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
|
||||
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
|
||||
return result;
|
||||
|
@ -615,7 +609,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
|||
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// TODO: Check that the types are valid.
|
||||
// An element-wise unary operation must have all operands and the result of
|
||||
|
@ -632,7 +626,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
|||
// dimensions with the result at this pre-optimization phase.
|
||||
// TODO: verify that dimensions match.
|
||||
// TODO: can the dimension of the result differ after optimizations?
|
||||
Value *alloc;
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
|
@ -647,7 +641,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
|||
|
||||
// Define loops.
|
||||
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
||||
std::vector<Value *> originalLoops;
|
||||
std::vector<Value> originalLoops;
|
||||
originalLoops.reserve(rank);
|
||||
for (auto result : loopsOp.getResults()) {
|
||||
originalLoops.push_back(result);
|
||||
|
@ -655,7 +649,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
|||
|
||||
// Define loop optimization.
|
||||
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
||||
std::vector<Value *> optimizedLoops;
|
||||
std::vector<Value> optimizedLoops;
|
||||
optimizedLoops.reserve(rank);
|
||||
for (auto result : optimizedLoopsOp.getResults()) {
|
||||
optimizedLoops.push_back(result);
|
||||
|
@ -695,7 +689,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
|||
rewriter.setInsertionPointToStart(&iterationBlock);
|
||||
|
||||
// Handle the operation:
|
||||
SmallVector<Value *, 4> loopIVs;
|
||||
SmallVector<Value, 4> loopIVs;
|
||||
for (auto arg : iterationBlock.getArguments())
|
||||
loopIVs.push_back(arg);
|
||||
|
||||
|
@ -718,7 +712,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
|||
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// TODO: Check that the types are valid.
|
||||
// An element-wise variadic operation must have all operands and the result
|
||||
|
@ -730,7 +724,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
|||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
|
||||
Value *alloc;
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
// If the output has a dynamic dimension, we compute its dimension at
|
||||
// runtime by using dimensions from the operands.
|
||||
|
@ -749,7 +743,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
|||
|
||||
// Define loops.
|
||||
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
||||
std::vector<Value *> originalLoops;
|
||||
std::vector<Value> originalLoops;
|
||||
originalLoops.reserve(rank);
|
||||
for (auto result : loopsOp.getResults()) {
|
||||
originalLoops.push_back(result);
|
||||
|
@ -757,7 +751,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
|||
|
||||
// Define loop optimization.
|
||||
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
||||
std::vector<Value *> optimizedLoops;
|
||||
std::vector<Value> optimizedLoops;
|
||||
optimizedLoops.reserve(rank);
|
||||
for (auto result : optimizedLoopsOp.getResults()) {
|
||||
optimizedLoops.push_back(result);
|
||||
|
@ -781,7 +775,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
|||
|
||||
// Get run-time dimension information for unknown dimensions used for
|
||||
// broadcasting.
|
||||
std::map<int, std::map<int, Value *>> broadcastedDimInfo =
|
||||
std::map<int, std::map<int, Value>> broadcastedDimInfo =
|
||||
getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
|
||||
|
||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||
|
@ -801,12 +795,12 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
|||
rewriter.setInsertionPointToStart(&iterationBlock);
|
||||
|
||||
// Handle the operation:
|
||||
SmallVector<Value *, 4> loopIVs;
|
||||
SmallVector<Value, 4> loopIVs;
|
||||
for (auto arg : iterationBlock.getArguments())
|
||||
loopIVs.push_back(arg);
|
||||
|
||||
// Fold over operands for each of their scalar values
|
||||
Value *accumulated, *next;
|
||||
Value accumulated, next;
|
||||
auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
|
||||
loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
|
||||
accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
|
||||
|
@ -831,17 +825,17 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
Value *alloc;
|
||||
Value alloc;
|
||||
|
||||
// Compute size in bytes.
|
||||
Value *tensorSize = rewriter.create<ConstantOp>(
|
||||
Value tensorSize = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
getMemRefEltSizeInBytes(memRefType)));
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
|
@ -849,14 +843,14 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
} else {
|
||||
auto memRefShape = memRefType.getShape();
|
||||
SmallVector<Value *, 4> allocOperands;
|
||||
SmallVector<Value, 4> allocOperands;
|
||||
for (int i = 0; i < memRefShape.size(); ++i) {
|
||||
// The shape array can always be used to construct shape information of
|
||||
// the result.
|
||||
Value *index = rewriter.create<ConstantOp>(
|
||||
Value index = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
||||
Value *loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
||||
Value *int64LoadedVal = rewriter.create<ZeroExtendIOp>(
|
||||
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
||||
Value int64LoadedVal = rewriter.create<ZeroExtendIOp>(
|
||||
loc, loadedVal, rewriter.getIntegerType(64));
|
||||
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
|
||||
allocOperands.push_back(rewriter.create<IndexCastOp>(
|
||||
|
|
|
@ -30,7 +30,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
|||
operandItr++;
|
||||
|
||||
// Organize operands into lower/upper bounds in affine.for ready formats.
|
||||
SmallVector<Value *, 4> lbOperands, ubOperands;
|
||||
SmallVector<Value, 4> lbOperands, ubOperands;
|
||||
AffineMap lbMap, ubMap;
|
||||
for (int boundType = 0; boundType < 2; boundType++) {
|
||||
auto &operands = boundType == 0 ? lbOperands : ubOperands;
|
||||
|
|
|
@ -51,7 +51,7 @@ public:
|
|||
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto *context = op->getContext();
|
||||
auto loc = op->getLoc();
|
||||
|
@ -66,27 +66,27 @@ public:
|
|||
// First operand.
|
||||
Type dstType =
|
||||
operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
||||
Value *alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
|
||||
Value *alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
|
||||
|
||||
// Second operand.
|
||||
Type srcType =
|
||||
operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
||||
Value *alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
|
||||
Value *alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
|
||||
|
||||
// Size.
|
||||
Value *int64Size = rewriter.create<LLVM::SExtOp>(
|
||||
Value int64Size = rewriter.create<LLVM::SExtOp>(
|
||||
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
|
||||
|
||||
// Memcpy call
|
||||
rewriter.create<CallOp>(
|
||||
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||
ArrayRef<Value *>(
|
||||
ArrayRef<Value>(
|
||||
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
|
@ -210,7 +210,7 @@ public:
|
|||
|
||||
// Retrieve dynamic mem refs from wrapped input, and convert every one of
|
||||
// them to static mem refs.
|
||||
SmallVector<Value *, 4> staticInputs;
|
||||
SmallVector<Value, 4> staticInputs;
|
||||
auto wrappedInput = entryPointEntryBlock.getArgument(0);
|
||||
for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
|
||||
// Call API function to retrieve the i-th dynamic memref.
|
||||
|
@ -225,13 +225,12 @@ public:
|
|||
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
||||
auto one = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(1));
|
||||
Value *ptrToMemRef =
|
||||
rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
|
||||
/*alignment=*/0);
|
||||
Value ptrToMemRef = rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
|
||||
/*alignment=*/0);
|
||||
|
||||
// Fill in the memref underlying ptrToMemRef with information extracted
|
||||
// from dynMemRef.
|
||||
fillPtrToMemRefWithDynMemRef(*dynMemRef, *ptrToMemRef, rewriter, loc,
|
||||
fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc,
|
||||
apiRegistry, llvmDialect);
|
||||
|
||||
// ptrToMemRef will be an input to main computation graph function.
|
||||
|
@ -261,8 +260,8 @@ public:
|
|||
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
||||
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
||||
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
||||
fillDynMemRefWithMemRef(*outMemRef, *outDynMemRef, rewriter, loc,
|
||||
apiRegistry, llvmDialect);
|
||||
fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry,
|
||||
llvmDialect);
|
||||
auto zero = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(0));
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
||||
|
@ -270,7 +269,7 @@ public:
|
|||
|
||||
// Return wrapped output.
|
||||
rewriter.create<LLVM::ReturnOp>(loc,
|
||||
SmallVector<Value *, 1>({wrappedOutput}));
|
||||
SmallVector<Value, 1>({wrappedOutput}));
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
|
@ -315,11 +314,11 @@ private:
|
|||
|
||||
// Call a registered API, return the return SSA values if only one result is
|
||||
// returned, otherwise return nullptr.
|
||||
Value *callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
|
||||
API apiId, ArrayRef<Value *> params) const {
|
||||
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
|
||||
API apiId, ArrayRef<Value> params) const {
|
||||
auto returnVals = rewriter.create<LLVM::CallOp>(
|
||||
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
|
||||
ArrayRef<Value *>(params));
|
||||
ArrayRef<Value>(params));
|
||||
if (returnVals.getNumResults() == 1)
|
||||
return returnVals.getResult(0);
|
||||
return nullptr;
|
||||
|
@ -348,12 +347,11 @@ private:
|
|||
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
|
||||
Value *memRef =
|
||||
rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, &ptrToMemRef);
|
||||
Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, ptrToMemRef);
|
||||
|
||||
// Set dataPtr and alignedDataPtr;
|
||||
auto dataPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {&dynMemRef});
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
|
||||
dataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, memRefTy.getStructElementType(0), dataPtr);
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
|
@ -373,9 +371,9 @@ private:
|
|||
// Get rank, sizes array ptr and strides array ptr.
|
||||
auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
|
||||
auto sizesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&dynMemRef});
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
|
||||
auto stridesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&dynMemRef});
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {dynMemRef});
|
||||
|
||||
for (decltype(rank) i = 0; i < rank; i++) {
|
||||
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
||||
|
@ -384,7 +382,7 @@ private:
|
|||
// Insert size of the dimension.
|
||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||
ArrayRef<Value *>({dimIdx}));
|
||||
ArrayRef<Value>({dimIdx}));
|
||||
auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(),
|
||||
dimSizePtr);
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
|
@ -395,7 +393,7 @@ private:
|
|||
// Insert stride of the dimension.
|
||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||
ArrayRef<Value *>({dimIdx}));
|
||||
ArrayRef<Value>({dimIdx}));
|
||||
auto dimStride = rewriter.create<LLVM::LoadOp>(
|
||||
loc, int64Ty.getPointerTo(), dimStridePtr);
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
|
@ -404,7 +402,7 @@ private:
|
|||
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
||||
}
|
||||
|
||||
rewriter.create<LLVM::StoreOp>(loc, memRef, &ptrToMemRef);
|
||||
rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
|
||||
}
|
||||
|
||||
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
|
||||
|
@ -415,19 +413,19 @@ private:
|
|||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
|
||||
// Extract the data pointer, and record it in dynamic mem ref created.
|
||||
Value *outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, outMemRefTy.getStructElementType(0), &outMemRef,
|
||||
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, outMemRefTy.getStructElementType(0), outMemRef,
|
||||
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
|
||||
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
||||
{&outDynMemRef, outMemRefDataPtr});
|
||||
{outDynMemRef, outMemRefDataPtr});
|
||||
|
||||
auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
|
||||
auto sizesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&outDynMemRef});
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
|
||||
auto stridesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&outDynMemRef});
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outDynMemRef});
|
||||
|
||||
for (decltype(rank) i = 0; i < rank; i++) {
|
||||
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
||||
|
@ -435,22 +433,22 @@ private:
|
|||
|
||||
// Transfer size of dimension from memref to dynamic memref.
|
||||
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, int64Ty, &outMemRef,
|
||||
loc, int64Ty, outMemRef,
|
||||
rewriter.getArrayAttr(
|
||||
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
|
||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||
ArrayRef<Value *>({dimIdx}));
|
||||
ArrayRef<Value>({dimIdx}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
|
||||
|
||||
// Transfer stride of dimension from memref to dynamic memref.
|
||||
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, int64Ty, &outMemRef,
|
||||
loc, int64Ty, outMemRef,
|
||||
rewriter.getArrayAttr(
|
||||
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), stridesArrayPtr,
|
||||
ArrayRef<Value *>({dimIdx}));
|
||||
ArrayRef<Value>({dimIdx}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
set(LLVM_LIT ${LLVM_SRC}/utils/lit/lit.py)
|
||||
set(LLVM_DEFAULT_EXTERNAL_LIT ${LLVM_BUILD}/bin/llvm-lit)
|
||||
set(LLVM_LIT ${LLVM_PROJ_SRC}/utils/lit/lit.py)
|
||||
set(LLVM_DEFAULT_EXTERNAL_LIT ${LLVM_PROJ_BUILD}/bin/llvm-lit)
|
||||
|
||||
configure_lit_site_cfg(${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import lit.llvm
|
||||
|
||||
config.llvm_tools_dir = "@MLIR_TOOLS_DIR@"
|
||||
config.mlir_obj_root = "@LLVM_BUILD@"
|
||||
config.mlir_obj_root = "@LLVM_PROJ_BUILD@"
|
||||
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
|
||||
config.suffixes = ['.mlir']
|
||||
|
||||
|
|
Loading…
Reference in New Issue