Merge pull request #9 from clang-ykt/use-mlir-in-llvm-project

Update to latest MLIR
This commit is contained in:
Tian Jin 2020-01-06 16:14:27 -05:00 committed by GitHub
commit 322002f509
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 214 additions and 227 deletions

View File

@ -7,52 +7,33 @@ jobs:
steps:
- checkout
- run:
name: "Pull Submodules"
name: Pull Submodules
command: |
git submodule update --init --recursive
- run:
name: Check current directory
command: pwd
- run:
name: Check current directory content
command: ls
- run:
name: Installing GCC
command: 'sudo apt-get update && sudo apt-get install -y gcc g++'
- run:
name: Install CMAKE
command: 'sudo apt-get update && sudo apt-get install -y cmake ninja-build'
- run:
name: Install Protobuf
command: 'sudo apt-get update && sudo apt-get install -y protobuf-compiler'
- run:
name: Check gcc version
command: gcc --version
name: Installing GCC, CMake, Ninja, Protobuf
command: sudo apt-get update && sudo apt-get install -y gcc g++ cmake ninja-build protobuf-compiler
# Use cached mlir installation if possible.
- restore_cache:
key: ONNF-MLIR-{{ arch }}
key: V2-LLVM-PROJECT-{{ arch }}
- run:
name: Install MLIR
command: |
# Check whether cache restoration succeeds by checking whether
# mlir-opt executable exists.
if [ ! -f llvm-project/build/bin/mlir-opt ]; then
git clone https://github.com/llvm/llvm-project.git
cd llvm-project && git checkout 9b6ad8466bb8b97082b705270603ad7f4559e931 && cd ..
git clone https://github.com/tensorflow/mlir llvm-project/llvm/projects/mlir
cd llvm-project/llvm/projects/mlir && git checkout 0710266d0f56cf6ab0f437badbd7416b6cecdf5f && cd ../../../..
mkdir llvm-project/build
cd llvm-project/build
cmake -G Ninja ../llvm -DLLVM_ENABLE_RTTI=ON -DLLVM_BUILD_EXAMPLES=OFF -DLLVM_TARGETS_TO_BUILD="host" -DCMAKE_BUILD_TYPE=Release
CMAKE_EXE_LINKER_FLAGS="-Wl,--reduce-memory-overheads -Wl,--hash-size=512" cmake --build . --target check-mlir -- -j 4
export MAKEFLAGS=-j4
source .circleci/install-mlir.sh
fi
- save_cache:
key: ONNF-MLIR-{{ arch }}
key: V2-LLVM-PROJECT-{{ arch }}
paths:
- llvm-project
- run:
name: Install ONNF
command: |
mkdir build && cd build
LLVM_SRC=$(pwd)/../llvm-project/llvm LLVM_BUILD=$(pwd)/../llvm-project/build cmake ..
LLVM_PROJ_SRC=$(pwd)/../llvm-project/ LLVM_PROJ_BUILD=$(pwd)/../llvm-project/build cmake ..
make all
LIT_OPTS=-v make check-mlir-lit
- run:

12
.circleci/install-mlir.sh Normal file
View File

@ -0,0 +1,12 @@
git clone https://github.com/llvm/llvm-project.git
mkdir llvm-project/build
cd llvm-project/build
cmake -G Ninja ../llvm \
-DLLVM_ENABLE_PROJECTS=mlir \
-DLLVM_BUILD_EXAMPLES=ON \
-DLLVM_TARGETS_TO_BUILD="host" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON
cmake --build . --target check-mlir -- ${MAKEFLAGS}

View File

@ -1,38 +1,38 @@
# Path to LLVM source folder.
if(DEFINED ENV{LLVM_SRC})
set(LLVM_SRC $ENV{LLVM_SRC})
if(EXISTS ${LLVM_SRC})
message(STATUS "LLVM_SRC " ${LLVM_SRC})
if(DEFINED ENV{LLVM_PROJ_SRC})
set(LLVM_PROJ_SRC $ENV{LLVM_PROJ_SRC})
if(EXISTS ${LLVM_PROJ_SRC})
message(STATUS "LLVM_PROJ_SRC " ${LLVM_PROJ_SRC})
else()
message(FATAL_ERROR "The path specified by LLVM_SRC does not exist: "
${LLVM_SRC})
message(FATAL_ERROR "The path specified by LLVM_PROJ_SRC does not exist: "
${LLVM_PROJ_SRC})
endif()
else()
message(FATAL_ERROR "env variable LLVM_SRC not set")
message(FATAL_ERROR "env variable LLVM_PROJ_SRC not set")
endif()
# Path to LLVM build folder
if(DEFINED ENV{LLVM_BUILD})
set(LLVM_BUILD $ENV{LLVM_BUILD})
if(EXISTS ${LLVM_BUILD})
message(STATUS "LLVM_BUILD " ${LLVM_BUILD})
if(DEFINED ENV{LLVM_PROJ_BUILD})
set(LLVM_PROJ_BUILD $ENV{LLVM_PROJ_BUILD})
if(EXISTS ${LLVM_PROJ_BUILD})
message(STATUS "LLVM_PROJ_BUILD " ${LLVM_PROJ_BUILD})
else()
message(FATAL_ERROR "The path specified by LLVM_BUILD does not exist: "
${LLVM_BUILD})
message(FATAL_ERROR "The path specified by LLVM_PROJ_BUILD does not exist: "
${LLVM_PROJ_BUILD})
endif()
else()
message(FATAL_ERROR "env variable LLVM_BUILD not set")
message(FATAL_ERROR "env variable LLVM_PROJ_BUILD not set")
endif()
# LLVM project lib folder
set(LLVM_PROJECT_LIB ${LLVM_BUILD}/lib)
set(LLVM_PROJECT_LIB ${LLVM_PROJ_BUILD}/lib)
# Include paths for MLIR
set(LLVM_SRC_INCLUDE_PATH ${LLVM_SRC}/include)
set(LLVM_BIN_INCLUDE_PATH ${LLVM_BUILD}/include)
set(MLIR_SRC_INCLUDE_PATH ${LLVM_SRC}/projects/mlir/include)
set(MLIR_BIN_INCLUDE_PATH ${LLVM_BUILD}/projects/mlir/include)
set(MLIR_TOOLS_DIR ${LLVM_BUILD}/bin)
set(LLVM_SRC_INCLUDE_PATH ${LLVM_PROJ_SRC}/llvm/include)
set(LLVM_BIN_INCLUDE_PATH ${LLVM_PROJ_BUILD}/include)
set(MLIR_SRC_INCLUDE_PATH ${LLVM_PROJ_SRC}/mlir/include)
set(MLIR_BIN_INCLUDE_PATH ${LLVM_PROJ_BUILD}/tools/mlir/include)
set(MLIR_TOOLS_DIR ${LLVM_PROJ_BUILD}/bin)
set(ONNF_TOOLS_DIR ${ONNF_BIN_ROOT}/bin)
set(ONNF_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
@ -173,7 +173,7 @@ function(whole_archive_link target lib_dir)
endfunction(whole_archive_link)
function(whole_archive_link_mlir target)
whole_archive_link(${target} ${LLVM_BUILD}/lib ${ARGN})
whole_archive_link(${target} ${LLVM_PROJ_BUILD}/lib ${ARGN})
endfunction(whole_archive_link_mlir)
function(whole_archive_link_onnf target)
@ -184,7 +184,7 @@ function(whole_archive_link_onnf target)
endfunction(whole_archive_link_onnf)
set(LLVM_CMAKE_DIR
"${LLVM_BUILD}/lib/cmake/llvm"
"${LLVM_PROJ_BUILD}/lib/cmake/llvm"
CACHE PATH "Path to LLVM cmake modules")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(AddLLVM)
@ -205,5 +205,5 @@ endfunction()
# table gen utility itself can be detected and cause re-compilation of .td file.
add_executable(mlir-tblgen IMPORTED)
set_property(TARGET mlir-tblgen
PROPERTY IMPORTED_LOCATION ${LLVM_BUILD}/bin/mlir-tblgen)
PROPERTY IMPORTED_LOCATION ${LLVM_PROJ_BUILD}/bin/mlir-tblgen)
set(MLIR_TABLEGEN_EXE mlir-tblgen)

View File

@ -71,7 +71,7 @@ struct OnnxOnnfSymbolMapping {
* @param name onnx tensor name.
* @return onnf tensor corresponding to `name`.
*/
mlir::Value *GetTensorByOnnxName(std::string name) {
mlir::Value GetTensorByOnnxName(std::string name) {
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
onnx_name2onnf_tensor.end() &&
"Tensor not found");
@ -81,9 +81,9 @@ struct OnnxOnnfSymbolMapping {
/*!
* Add a new mapping from onnx tensor name to MLIR symbol.
* @param name onnx tensor name.
* @param tensor MLIR Value* pointer.
* @param tensor MLIR Value pointer.
*/
void AddMapping(std::string name, mlir::Value *tensor) {
void AddMapping(std::string name, mlir::Value tensor) {
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
"Tensor already exists.");
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
@ -97,7 +97,7 @@ private:
/*!
* mapping from onnx tensor names to MLIR tensor.
*/
std::map<std::string, mlir::Value*> onnx_name2onnf_tensor;
std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
};
class FrontendGenImpl {
@ -192,13 +192,13 @@ private:
/*!
* Import a input tensor symbol by recording a new entry in frontend_symbols_
* recording the mapping between legalized onnx tensor name and mlir::Value*
* recording the mapping between legalized onnx tensor name and mlir::Value
* for further lookup in computation node importing.
* @param input onnx input tensor ValueInfoProto.
* @param symbol mlir input argument.
*/
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
mlir::Value *symbol) {
mlir::Value symbol) {
auto input_tensor_legalized_name = legalize_name(input.name());
assert(
!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
@ -480,7 +480,7 @@ private:
}
void ImportNodeGeneric(onnx::NodeProto node) {
std::vector<mlir::Value *> inputs;
std::vector<mlir::Value> inputs;
for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
@ -515,7 +515,7 @@ private:
onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
std::vector<mlir::Value *> inputs;
std::vector<mlir::Value> inputs;
for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
@ -562,7 +562,7 @@ private:
onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
std::vector<mlir::Value *> inputs;
std::vector<mlir::Value> inputs;
for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
@ -633,7 +633,7 @@ private:
}
void ImportNode(onnx::NodeProto node) {
std::vector<mlir::Value *> inputs;
std::vector<mlir::Value> inputs;
for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
@ -662,17 +662,17 @@ private:
* Import output tensor, by doing the following:
* - Add the type of this output tensor to a list of tensor
* types representing return types of this graph function.
* - Add this output tensor to the list of mlir::Value*
* - Add this output tensor to the list of mlir::Value
* to be returned by the function representing computation graph.
* @param output onnx output tensor ValueInfoProto.
* @param ret_types a vector of tensor types representing graph's
* output tensor types.
* @param ret_vals a vector of mlir Value* representing graph's
* @param ret_vals a vector of mlir Value representing graph's
* output tensor.
*/
void ImportOutputTensor(const onnx::ValueInfoProto &output,
llvm::SmallVectorImpl<mlir::Type> &ret_types,
llvm::SmallVectorImpl<mlir::Value *> &ret_vals) {
llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
auto output_tensor_legalized_name = legalize_name(output.name());
assert(
frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
@ -722,7 +722,7 @@ private:
}
llvm::SmallVector<mlir::Type, 4> ret_types;
llvm::SmallVector<mlir::Value *, 4> ret_vals;
llvm::SmallVector<mlir::Value, 4> ret_vals;
// Import the output tensors
for (const auto &output : graph.output()) {
ImportOutputTensor(output, ret_types, ret_vals);

View File

@ -9,8 +9,9 @@ namespace onnf {
using namespace mlir;
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
const Type& operandType, Value*& operand) {
ParseResult
KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
Value &operand) {
// If operand queue is empty, parse more operands and cache them.
if (_operandRefQueue.empty()) {
// Parse operand types:
@ -27,7 +28,7 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
auto operand_ref = _operandRefQueue.front();
_operandRefQueue.pop();
llvm::SmallVector<Value*, 1> operands;
llvm::SmallVector<Value, 1> operands;
_parser.resolveOperand(operand_ref, operandType, operands);
operand = operands.front();
return success();
@ -38,8 +39,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
}
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
Value* operand = nullptr;
const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) {
Value operand = nullptr;
if (ParseOptionalOperand(operandType, operand))
return failure();
@ -47,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
return success();
}
ParseResult KrnlDialectOperandParser::ParseOperand(
const Type& operandType, Value*& operand) {
ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType,
Value &operand) {
if (ParseOptionalOperand(operandType, operand))
return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand.");
@ -56,7 +57,7 @@ ParseResult KrnlDialectOperandParser::ParseOperand(
}
ParseResult KrnlDialectOperandParser::ParseOperand(
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) {
if (ParseOptionalOperand(operandType, operandList))
return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand.");
@ -129,7 +130,7 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
boundMaps.emplace_back(AffineMapAttr::get(map));
}
void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) {
void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
AffineMap map = builder.getSymbolIdentityMap();

View File

@ -17,20 +17,22 @@ class KrnlDialectOperandParser {
: _parser(parser), _builder(parser.getBuilder()){};
// Parse an optional operand.
mlir::ParseResult ParseOptionalOperand(
const mlir::Type& operandType, mlir::Value*& operand);
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
mlir::Value &operand);
// Parse an optional operand and push it to an operand list.
mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType,
llvm::SmallVectorImpl<mlir::Value*>& operandList);
mlir::ParseResult
ParseOptionalOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
// Parse a required operand.
mlir::ParseResult ParseOperand(
const mlir::Type& operandType, mlir::Value*& operand);
mlir::ParseResult ParseOperand(const mlir::Type &operandType,
mlir::Value &operand);
// Parse a required operand and push it to an operand list.
mlir::ParseResult ParseOperand(const mlir::Type& operandType,
llvm::SmallVectorImpl<mlir::Value*>& operandList);
mlir::ParseResult
ParseOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
// Do we have more operands to parse?
bool hasOperandLeft() { return !_operandRefQueue.empty(); }
@ -63,11 +65,10 @@ void printBound(mlir::AffineMapAttr boundMap,
namespace mlir {
struct KrnlIterateOperandPack {
KrnlIterateOperandPack(mlir::Builder& builder,
llvm::ArrayRef<mlir::Value*> inputLoops,
llvm::ArrayRef<mlir::Value*> optimizedLoops)
: builder(builder),
inputLoops(inputLoops),
KrnlIterateOperandPack(mlir::Builder &builder,
llvm::ArrayRef<mlir::Value> inputLoops,
llvm::ArrayRef<mlir::Value> optimizedLoops)
: builder(builder), inputLoops(inputLoops),
optimizedLoops(optimizedLoops) {
_operands.insert(
_operands.end(), optimizedLoops.begin(), optimizedLoops.end());
@ -75,9 +76,9 @@ struct KrnlIterateOperandPack {
void pushConstantBound(int64_t bound);
void pushOperandBound(mlir::Value* operand);
void pushOperandBound(mlir::Value operand);
llvm::SmallVector<mlir::Value*, 8> getOperands() const { return _operands; }
llvm::SmallVector<mlir::Value, 8> getOperands() const { return _operands; }
mlir::ArrayAttr getAttributes() const {
return builder.getArrayAttr(boundMaps);
@ -90,11 +91,11 @@ struct KrnlIterateOperandPack {
private:
int _boundIdx = 0;
llvm::SmallVector<mlir::Value*, 8> _operands;
llvm::SmallVector<mlir::Value, 8> _operands;
llvm::SmallVector<mlir::Attribute, 8> boundMaps;
llvm::ArrayRef<mlir::Value*> inputLoops, optimizedLoops;
llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;
mlir::Builder& builder;
};

View File

@ -44,21 +44,21 @@ static MemRefType convertTensorToMemRef(TensorType type) {
}
/// Insert an allocation and deallocation for the given MemRefType.
static Value *insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter,
bool insertDealloc,
ArrayRef<Value *> operands = {}) {
static Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter,
bool insertDealloc,
ArrayRef<Value> operands = {}) {
// Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc;
if (!operands.empty()) {
auto memRefShape = type.getShape();
auto rank = memRefShape.size();
std::map<int, Value *> fromOperands;
std::map<int, Value> fromOperands;
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
int memRefDimIdx = rank - 1 - reversedIdx;
if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
Value *maxDim = nullptr;
Value maxDim = nullptr;
for (int i = 0; i < operands.size(); i++) {
auto operandShape =
operands[i]->getType().cast<MemRefType>().getShape();
@ -85,7 +85,7 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc,
}
}
SmallVector<Value *, 4> allocOperands;
SmallVector<Value, 4> allocOperands;
for (int i = 0; i < rank; ++i)
if (memRefShape[i] < 0)
allocOperands.push_back(fromOperands[i]);
@ -146,14 +146,14 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value *>>
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
MemRefType memRefType, ArrayRef<Value *> operands) {
MemRefType memRefType, ArrayRef<Value> operands) {
auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size();
// For unknown dimensions, we need to get dimension values at runtime in
// order to do broadcasting.
std::map<int, std::map<int, Value *>> DimInfo;
std::map<int, std::map<int, Value>> DimInfo;
// For each result dimension, compute the number of sharing operands.
// Sharing operands are operands sharing the same index (counting from the
// rightmost to the leftmost) for a given dimension.
@ -173,7 +173,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
// We only care about unknown dimensions whose number of sharing operands is
// more than one, since they are potentially broadcasted dimensions.
for (int i = 0; i < operands.size(); ++i) {
std::map<int, Value *> broadcastedDims;
std::map<int, Value> broadcastedDims;
auto shape = operands[i]->getType().cast<MemRefType>().getShape();
int size = shape.size();
for (int j = 0; j < shape.size(); ++j) {
@ -192,17 +192,17 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value *>
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value *> loopIVs, Value *operand,
std::map<int, Value *> broadcastedDims) {
ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value> broadcastedDims) {
// `operand` must has a ranked type. This should have been checked by the
// shape inference pass.
auto operandShape = operand->getType().cast<MemRefType>().getShape();
auto rank = operandShape.size();
auto loopCount = loopIVs.size();
std::vector<Value *> newLoopIVs;
std::vector<Value> newLoopIVs;
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
auto dimIdx = rank - 1 - reversedIdx;
auto loopIdx = loopCount - 1 - reversedIdx;
@ -247,7 +247,7 @@ struct ScalarOp<ONNXMulOp> {
template <>
struct ScalarOp<ONNXDivOp> {
using FOp = DivFOp;
using IOp = DivISOp;
using IOp = SignedDivIOp;
};
template <>
@ -295,9 +295,9 @@ using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
// Scalar unary ops for lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename UnaryOp>
Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
/* Lower UnaryOp to Ops in the Standard dialect.
*/
auto loc = op->getLoc();
@ -318,14 +318,13 @@ Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXTanhOp
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
auto loc = op->getLoc();
Value *operand = operands[0];
Value operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
@ -342,14 +341,13 @@ Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op,
// Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2)
auto loc = op->getLoc();
Value *operand = operands[0];
Value operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
@ -366,14 +364,13 @@ Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
// Scalar unary ops for lowering ONNXCoshOp
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2)
auto loc = op->getLoc();
Value *operand = operands[0];
Value operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
@ -390,14 +387,14 @@ Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
// Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
auto loc = op->getLoc();
Value *operand = operands[0];
Value operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
@ -413,8 +410,8 @@ Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
// Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value *> operands,
Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// %Y = AddFOp(MulFOp(alpha, %X), beta)
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
@ -424,7 +421,7 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
// %Z,
// Constant 1)
auto loc = op->getLoc();
Value *operand = operands[0];
Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
@ -449,14 +446,14 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
// Scalar unary ops for lowering ONNXEluOp
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
// %X)
auto loc = op->getLoc();
Value *operand = operands[0];
Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
@ -478,15 +475,14 @@ Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXReluOp
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// ConstantOp 0,
// %X)
auto loc = op->getLoc();
Value *operand = operands[0];
Value operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto lessThanZero =
@ -500,15 +496,15 @@ Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op,
// Scalar unary ops for lowering ONNXLeakyReluOp
//===----------------------------------------------------------------------===//
template <>
Value *
mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, %X),
// %X)
auto loc = op->getLoc();
Value *operand = operands[0];
Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
@ -525,17 +521,16 @@ mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXSeluOp
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
// MulFOp(gamma, %X),
// MulFOp(gamma,
// SubFOp(MulFOp(alpha, ExpOp(%X)),
// alpha)))
auto loc = op->getLoc();
Value *operand = operands[0];
Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
@ -558,13 +553,12 @@ Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
// Scalar unary ops for lowering ONNXReciprocalOp
//===----------------------------------------------------------------------===//
template <>
Value *
mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp<ONNXReciprocalOp>(
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
auto loc = op->getLoc();
Value *operand = operands[0];
Value operand = operands[0];
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto result = rewriter.create<DivFOp>(loc, one, operand);
@ -576,15 +570,15 @@ mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
// %X,
// %Y)
auto loc = op->getLoc();
Value *lhs = operands[0];
Value *rhs = operands[1];
Value lhs = operands[0];
Value rhs = operands[1];
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
return result;
@ -594,15 +588,15 @@ Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXMinOp
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
// %X,
// %Y)
auto loc = op->getLoc();
Value *lhs = operands[0];
Value *rhs = operands[1];
Value lhs = operands[0];
Value rhs = operands[1];
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
return result;
@ -615,7 +609,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid.
// An element-wise unary operation must have all operands and the result of
@ -632,7 +626,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// dimensions with the result at this pre-optimization phase.
// TODO: verify that dimensions match.
// TODO: can the dimension of the result differ after optimizations?
Value *alloc;
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType))
@ -647,7 +641,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
std::vector<Value *> originalLoops;
std::vector<Value> originalLoops;
originalLoops.reserve(rank);
for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result);
@ -655,7 +649,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// Define loop optimization.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
std::vector<Value *> optimizedLoops;
std::vector<Value> optimizedLoops;
optimizedLoops.reserve(rank);
for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result);
@ -695,7 +689,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operation:
SmallVector<Value *, 4> loopIVs;
SmallVector<Value, 4> loopIVs;
for (auto arg : iterationBlock.getArguments())
loopIVs.push_back(arg);
@ -718,7 +712,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid.
// An element-wise variadic operation must have all operands and the result
@ -730,7 +724,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
Value *alloc;
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
// If the output has a dynamic dimension, we compute its dimension at
// runtime by using dimensions from the operands.
@ -749,7 +743,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
std::vector<Value *> originalLoops;
std::vector<Value> originalLoops;
originalLoops.reserve(rank);
for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result);
@ -757,7 +751,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Define loop optimization.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
std::vector<Value *> optimizedLoops;
std::vector<Value> optimizedLoops;
optimizedLoops.reserve(rank);
for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result);
@ -781,7 +775,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value *>> broadcastedDimInfo =
std::map<int, std::map<int, Value>> broadcastedDimInfo =
getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
@ -801,12 +795,12 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operation:
SmallVector<Value *, 4> loopIVs;
SmallVector<Value, 4> loopIVs;
for (auto arg : iterationBlock.getArguments())
loopIVs.push_back(arg);
// Fold over operands for each of their scalar values
Value *accumulated, *next;
Value accumulated, next;
auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
@ -831,17 +825,17 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
Value *alloc;
Value alloc;
// Compute size in bytes.
Value *tensorSize = rewriter.create<ConstantOp>(
Value tensorSize = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
getMemRefEltSizeInBytes(memRefType)));
bool insertDealloc = checkInsertDealloc(op);
@ -849,14 +843,14 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
} else {
auto memRefShape = memRefType.getShape();
SmallVector<Value *, 4> allocOperands;
SmallVector<Value, 4> allocOperands;
for (int i = 0; i < memRefShape.size(); ++i) {
// The shape array can always be used to construct shape information of
// the result.
Value *index = rewriter.create<ConstantOp>(
Value index = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
Value *loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
Value *int64LoadedVal = rewriter.create<ZeroExtendIOp>(
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
Value int64LoadedVal = rewriter.create<ZeroExtendIOp>(
loc, loadedVal, rewriter.getIntegerType(64));
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
allocOperands.push_back(rewriter.create<IndexCastOp>(

View File

@ -30,7 +30,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
operandItr++;
// Organize operands into lower/upper bounds in affine.for ready formats.
SmallVector<Value *, 4> lbOperands, ubOperands;
SmallVector<Value, 4> lbOperands, ubOperands;
AffineMap lbMap, ubMap;
for (int boundType = 0; boundType < 2; boundType++) {
auto &operands = boundType == 0 ? lbOperands : ubOperands;

View File

@ -51,7 +51,7 @@ public:
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext();
auto loc = op->getLoc();
@ -66,27 +66,27 @@ public:
// First operand.
Type dstType =
operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
Value *alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
Value *alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
// Second operand.
Type srcType =
operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
Value *alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
Value *alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
// Size.
Value *int64Size = rewriter.create<LLVM::SExtOp>(
Value int64Size = rewriter.create<LLVM::SExtOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
// Memcpy call
rewriter.create<CallOp>(
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value *>(
ArrayRef<Value>(
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
rewriter.eraseOp(op);
@ -210,7 +210,7 @@ public:
// Retrieve dynamic mem refs from wrapped input, and convert every one of
// them to static mem refs.
SmallVector<Value *, 4> staticInputs;
SmallVector<Value, 4> staticInputs;
auto wrappedInput = entryPointEntryBlock.getArgument(0);
for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
// Call API function to retrieve the i-th dynamic memref.
@ -225,13 +225,12 @@ public:
auto memRefTy = memRefPtrTy.getPointerElementTy();
auto one = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(1));
Value *ptrToMemRef =
rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
/*alignment=*/0);
Value ptrToMemRef = rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
/*alignment=*/0);
// Fill in the memref underlying ptrToMemRef with information extracted
// from dynMemRef.
fillPtrToMemRefWithDynMemRef(*dynMemRef, *ptrToMemRef, rewriter, loc,
fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc,
apiRegistry, llvmDialect);
// ptrToMemRef will be an input to main computation graph function.
@ -261,8 +260,8 @@ public:
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
fillDynMemRefWithMemRef(*outMemRef, *outDynMemRef, rewriter, loc,
apiRegistry, llvmDialect);
fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry,
llvmDialect);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(0));
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
@ -270,7 +269,7 @@ public:
// Return wrapped output.
rewriter.create<LLVM::ReturnOp>(loc,
SmallVector<Value *, 1>({wrappedOutput}));
SmallVector<Value, 1>({wrappedOutput}));
return matchSuccess();
}
@ -315,11 +314,11 @@ private:
// Call a registered API, return the return SSA values if only one result is
// returned, otherwise return nullptr.
Value *callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
API apiId, ArrayRef<Value *> params) const {
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
API apiId, ArrayRef<Value> params) const {
auto returnVals = rewriter.create<LLVM::CallOp>(
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
ArrayRef<Value *>(params));
ArrayRef<Value>(params));
if (returnVals.getNumResults() == 1)
return returnVals.getResult(0);
return nullptr;
@ -348,12 +347,11 @@ private:
auto memRefTy = memRefPtrTy.getPointerElementTy();
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
Value *memRef =
rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, &ptrToMemRef);
Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, ptrToMemRef);
// Set dataPtr and alignedDataPtr;
auto dataPtr =
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {&dynMemRef});
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
dataPtr = rewriter.create<LLVM::BitcastOp>(
loc, memRefTy.getStructElementType(0), dataPtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
@ -373,9 +371,9 @@ private:
// Get rank, sizes array ptr and strides array ptr.
auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
auto sizesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&dynMemRef});
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
auto stridesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&dynMemRef});
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {dynMemRef});
for (decltype(rank) i = 0; i < rank; i++) {
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
@ -384,7 +382,7 @@ private:
// Insert size of the dimension.
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value *>({dimIdx}));
ArrayRef<Value>({dimIdx}));
auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(),
dimSizePtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
@ -395,7 +393,7 @@ private:
// Insert stride of the dimension.
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value *>({dimIdx}));
ArrayRef<Value>({dimIdx}));
auto dimStride = rewriter.create<LLVM::LoadOp>(
loc, int64Ty.getPointerTo(), dimStridePtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
@ -404,7 +402,7 @@ private:
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
}
rewriter.create<LLVM::StoreOp>(loc, memRef, &ptrToMemRef);
rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
}
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
@ -415,19 +413,19 @@ private:
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
// Extract the data pointer, and record it in dynamic mem ref created.
Value *outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
loc, outMemRefTy.getStructElementType(0), &outMemRef,
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
loc, outMemRefTy.getStructElementType(0), outMemRef,
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
{&outDynMemRef, outMemRefDataPtr});
{outDynMemRef, outMemRefDataPtr});
auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
auto sizesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&outDynMemRef});
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
auto stridesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&outDynMemRef});
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outDynMemRef});
for (decltype(rank) i = 0; i < rank; i++) {
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
@ -435,22 +433,22 @@ private:
// Transfer size of dimension from memref to dynamic memref.
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
loc, int64Ty, &outMemRef,
loc, int64Ty, outMemRef,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value *>({dimIdx}));
ArrayRef<Value>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
// Transfer stride of dimension from memref to dynamic memref.
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
loc, int64Ty, &outMemRef,
loc, int64Ty, outMemRef,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), stridesArrayPtr,
ArrayRef<Value *>({dimIdx}));
ArrayRef<Value>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
}
}

View File

@ -1,5 +1,5 @@
set(LLVM_LIT ${LLVM_SRC}/utils/lit/lit.py)
set(LLVM_DEFAULT_EXTERNAL_LIT ${LLVM_BUILD}/bin/llvm-lit)
set(LLVM_LIT ${LLVM_PROJ_SRC}/utils/lit/lit.py)
set(LLVM_DEFAULT_EXTERNAL_LIT ${LLVM_PROJ_BUILD}/bin/llvm-lit)
configure_lit_site_cfg(${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py

View File

@ -2,7 +2,7 @@
import lit.llvm
config.llvm_tools_dir = "@MLIR_TOOLS_DIR@"
config.mlir_obj_root = "@LLVM_BUILD@"
config.mlir_obj_root = "@LLVM_PROJ_BUILD@"
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
config.suffixes = ['.mlir']