Transition to value-typed Value, rename Value* -> Value, in accordance with upstream MLIR style change.
This commit is contained in:
parent
33f988e18a
commit
0582846864
|
@ -71,7 +71,7 @@ struct OnnxOnnfSymbolMapping {
|
||||||
* @param name onnx tensor name.
|
* @param name onnx tensor name.
|
||||||
* @return onnf tensor corresponding to `name`.
|
* @return onnf tensor corresponding to `name`.
|
||||||
*/
|
*/
|
||||||
mlir::Value *GetTensorByOnnxName(std::string name) {
|
mlir::Value GetTensorByOnnxName(std::string name) {
|
||||||
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
|
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
|
||||||
onnx_name2onnf_tensor.end() &&
|
onnx_name2onnf_tensor.end() &&
|
||||||
"Tensor not found");
|
"Tensor not found");
|
||||||
|
@ -81,9 +81,9 @@ struct OnnxOnnfSymbolMapping {
|
||||||
/*!
|
/*!
|
||||||
* Add a new mapping from onnx tensor name to MLIR symbol.
|
* Add a new mapping from onnx tensor name to MLIR symbol.
|
||||||
* @param name onnx tensor name.
|
* @param name onnx tensor name.
|
||||||
* @param tensor MLIR Value* pointer.
|
* @param tensor MLIR Value to associate with the onnx tensor name.
|
||||||
*/
|
*/
|
||||||
void AddMapping(std::string name, mlir::Value *tensor) {
|
void AddMapping(std::string name, mlir::Value tensor) {
|
||||||
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
|
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
|
||||||
"Tensor already exists.");
|
"Tensor already exists.");
|
||||||
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
|
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
|
||||||
|
@ -97,7 +97,7 @@ private:
|
||||||
/*!
|
/*!
|
||||||
* mapping from onnx tensor names to MLIR tensor.
|
* mapping from onnx tensor names to MLIR tensor.
|
||||||
*/
|
*/
|
||||||
std::map<std::string, mlir::Value*> onnx_name2onnf_tensor;
|
std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
|
||||||
};
|
};
|
||||||
|
|
||||||
class FrontendGenImpl {
|
class FrontendGenImpl {
|
||||||
|
@ -192,13 +192,13 @@ private:
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Import an input tensor symbol by recording a new entry in frontend_symbols_
|
* Import an input tensor symbol by recording a new entry in frontend_symbols_
|
||||||
* recording the mapping between legalized onnx tensor name and mlir::Value*
|
* recording the mapping between legalized onnx tensor name and mlir::Value
|
||||||
* for further lookup in computation node importing.
|
* for further lookup in computation node importing.
|
||||||
* @param input onnx input tensor ValueInfoProto.
|
* @param input onnx input tensor ValueInfoProto.
|
||||||
* @param symbol mlir input argument.
|
* @param symbol mlir input argument.
|
||||||
*/
|
*/
|
||||||
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
|
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
|
||||||
mlir::Value *symbol) {
|
mlir::Value symbol) {
|
||||||
auto input_tensor_legalized_name = legalize_name(input.name());
|
auto input_tensor_legalized_name = legalize_name(input.name());
|
||||||
assert(
|
assert(
|
||||||
!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
|
!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
|
||||||
|
@ -480,7 +480,7 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportNodeGeneric(onnx::NodeProto node) {
|
void ImportNodeGeneric(onnx::NodeProto node) {
|
||||||
std::vector<mlir::Value *> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (auto item : node.input()) {
|
for (auto item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
|
@ -515,7 +515,7 @@ private:
|
||||||
onnx::NodeProto node, int nIn, int nOut,
|
onnx::NodeProto node, int nIn, int nOut,
|
||||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||||
attrs) {
|
attrs) {
|
||||||
std::vector<mlir::Value *> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (auto item : node.input()) {
|
for (auto item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
|
@ -562,7 +562,7 @@ private:
|
||||||
onnx::NodeProto node, int nIn, int nOut,
|
onnx::NodeProto node, int nIn, int nOut,
|
||||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||||
attrs) {
|
attrs) {
|
||||||
std::vector<mlir::Value *> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (auto item : node.input()) {
|
for (auto item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
|
@ -633,7 +633,7 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportNode(onnx::NodeProto node) {
|
void ImportNode(onnx::NodeProto node) {
|
||||||
std::vector<mlir::Value *> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (auto item : node.input()) {
|
for (auto item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
|
@ -662,17 +662,17 @@ private:
|
||||||
* Import output tensor, by doing the following:
|
* Import output tensor, by doing the following:
|
||||||
* - Add the type of this output tensor to a list of tensor
|
* - Add the type of this output tensor to a list of tensor
|
||||||
* types representing return types of this graph function.
|
* types representing return types of this graph function.
|
||||||
* - Add this output tensor to the list of mlir::Value*
|
* - Add this output tensor to the list of mlir::Value
|
||||||
* to be returned by the function representing computation graph.
|
* to be returned by the function representing computation graph.
|
||||||
* @param output onnx output tensor ValueInfoProto.
|
* @param output onnx output tensor ValueInfoProto.
|
||||||
* @param ret_types a vector of tensor types representing graph's
|
* @param ret_types a vector of tensor types representing graph's
|
||||||
* output tensor types.
|
* output tensor types.
|
||||||
* @param ret_vals a vector of mlir Value* representing graph's
|
* @param ret_vals a vector of mlir Value representing graph's
|
||||||
* output tensor.
|
* output tensor.
|
||||||
*/
|
*/
|
||||||
void ImportOutputTensor(const onnx::ValueInfoProto &output,
|
void ImportOutputTensor(const onnx::ValueInfoProto &output,
|
||||||
llvm::SmallVectorImpl<mlir::Type> &ret_types,
|
llvm::SmallVectorImpl<mlir::Type> &ret_types,
|
||||||
llvm::SmallVectorImpl<mlir::Value *> &ret_vals) {
|
llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
|
||||||
auto output_tensor_legalized_name = legalize_name(output.name());
|
auto output_tensor_legalized_name = legalize_name(output.name());
|
||||||
assert(
|
assert(
|
||||||
frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
|
frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
|
||||||
|
@ -722,7 +722,7 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||||
llvm::SmallVector<mlir::Value *, 4> ret_vals;
|
llvm::SmallVector<mlir::Value, 4> ret_vals;
|
||||||
// Import the output tensors
|
// Import the output tensors
|
||||||
for (const auto &output : graph.output()) {
|
for (const auto &output : graph.output()) {
|
||||||
ImportOutputTensor(output, ret_types, ret_vals);
|
ImportOutputTensor(output, ret_types, ret_vals);
|
||||||
|
|
|
@ -9,8 +9,9 @@ namespace onnf {
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
ParseResult
|
||||||
const Type& operandType, Value*& operand) {
|
KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
|
||||||
|
Value &operand) {
|
||||||
// If operand queue is empty, parse more operands and cache them.
|
// If operand queue is empty, parse more operands and cache them.
|
||||||
if (_operandRefQueue.empty()) {
|
if (_operandRefQueue.empty()) {
|
||||||
// Parse operand types:
|
// Parse operand types:
|
||||||
|
@ -27,7 +28,7 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||||
auto operand_ref = _operandRefQueue.front();
|
auto operand_ref = _operandRefQueue.front();
|
||||||
_operandRefQueue.pop();
|
_operandRefQueue.pop();
|
||||||
|
|
||||||
llvm::SmallVector<Value*, 1> operands;
|
llvm::SmallVector<Value, 1> operands;
|
||||||
_parser.resolveOperand(operand_ref, operandType, operands);
|
_parser.resolveOperand(operand_ref, operandType, operands);
|
||||||
operand = operands.front();
|
operand = operands.front();
|
||||||
return success();
|
return success();
|
||||||
|
@ -38,8 +39,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||||
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
|
const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) {
|
||||||
Value* operand = nullptr;
|
Value operand = nullptr;
|
||||||
if (ParseOptionalOperand(operandType, operand))
|
if (ParseOptionalOperand(operandType, operand))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -47,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult KrnlDialectOperandParser::ParseOperand(
|
ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType,
|
||||||
const Type& operandType, Value*& operand) {
|
Value &operand) {
|
||||||
if (ParseOptionalOperand(operandType, operand))
|
if (ParseOptionalOperand(operandType, operand))
|
||||||
return _parser.emitError(
|
return _parser.emitError(
|
||||||
_parser.getCurrentLocation(), "Expecting an operand.");
|
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||||
|
@ -56,7 +57,7 @@ ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult KrnlDialectOperandParser::ParseOperand(
|
ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||||
const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) {
|
const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) {
|
||||||
if (ParseOptionalOperand(operandType, operandList))
|
if (ParseOptionalOperand(operandType, operandList))
|
||||||
return _parser.emitError(
|
return _parser.emitError(
|
||||||
_parser.getCurrentLocation(), "Expecting an operand.");
|
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||||
|
@ -129,7 +130,7 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
|
||||||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||||
}
|
}
|
||||||
|
|
||||||
void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) {
|
void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
|
||||||
if (boundMaps.size() % 2 == 0)
|
if (boundMaps.size() % 2 == 0)
|
||||||
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
|
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
|
||||||
AffineMap map = builder.getSymbolIdentityMap();
|
AffineMap map = builder.getSymbolIdentityMap();
|
||||||
|
|
|
@ -17,20 +17,22 @@ class KrnlDialectOperandParser {
|
||||||
: _parser(parser), _builder(parser.getBuilder()){};
|
: _parser(parser), _builder(parser.getBuilder()){};
|
||||||
|
|
||||||
// Parse an optional operand.
|
// Parse an optional operand.
|
||||||
mlir::ParseResult ParseOptionalOperand(
|
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
|
||||||
const mlir::Type& operandType, mlir::Value*& operand);
|
mlir::Value &operand);
|
||||||
|
|
||||||
// Parse an optional operand and push it to an operand list.
|
// Parse an optional operand and push it to an operand list.
|
||||||
mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType,
|
mlir::ParseResult
|
||||||
llvm::SmallVectorImpl<mlir::Value*>& operandList);
|
ParseOptionalOperand(const mlir::Type &operandType,
|
||||||
|
llvm::SmallVectorImpl<mlir::Value> &operandList);
|
||||||
|
|
||||||
// Parse a required operand.
|
// Parse a required operand.
|
||||||
mlir::ParseResult ParseOperand(
|
mlir::ParseResult ParseOperand(const mlir::Type &operandType,
|
||||||
const mlir::Type& operandType, mlir::Value*& operand);
|
mlir::Value &operand);
|
||||||
|
|
||||||
// Parse a required operand and push it to an operand list.
|
// Parse a required operand and push it to an operand list.
|
||||||
mlir::ParseResult ParseOperand(const mlir::Type& operandType,
|
mlir::ParseResult
|
||||||
llvm::SmallVectorImpl<mlir::Value*>& operandList);
|
ParseOperand(const mlir::Type &operandType,
|
||||||
|
llvm::SmallVectorImpl<mlir::Value> &operandList);
|
||||||
|
|
||||||
// Do we have more operands to parse?
|
// Do we have more operands to parse?
|
||||||
bool hasOperandLeft() { return !_operandRefQueue.empty(); }
|
bool hasOperandLeft() { return !_operandRefQueue.empty(); }
|
||||||
|
@ -64,10 +66,9 @@ namespace mlir {
|
||||||
|
|
||||||
struct KrnlIterateOperandPack {
|
struct KrnlIterateOperandPack {
|
||||||
KrnlIterateOperandPack(mlir::Builder &builder,
|
KrnlIterateOperandPack(mlir::Builder &builder,
|
||||||
llvm::ArrayRef<mlir::Value*> inputLoops,
|
llvm::ArrayRef<mlir::Value> inputLoops,
|
||||||
llvm::ArrayRef<mlir::Value*> optimizedLoops)
|
llvm::ArrayRef<mlir::Value> optimizedLoops)
|
||||||
: builder(builder),
|
: builder(builder), inputLoops(inputLoops),
|
||||||
inputLoops(inputLoops),
|
|
||||||
optimizedLoops(optimizedLoops) {
|
optimizedLoops(optimizedLoops) {
|
||||||
_operands.insert(
|
_operands.insert(
|
||||||
_operands.end(), optimizedLoops.begin(), optimizedLoops.end());
|
_operands.end(), optimizedLoops.begin(), optimizedLoops.end());
|
||||||
|
@ -75,9 +76,9 @@ struct KrnlIterateOperandPack {
|
||||||
|
|
||||||
void pushConstantBound(int64_t bound);
|
void pushConstantBound(int64_t bound);
|
||||||
|
|
||||||
void pushOperandBound(mlir::Value* operand);
|
void pushOperandBound(mlir::Value operand);
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value*, 8> getOperands() const { return _operands; }
|
llvm::SmallVector<mlir::Value, 8> getOperands() const { return _operands; }
|
||||||
|
|
||||||
mlir::ArrayAttr getAttributes() const {
|
mlir::ArrayAttr getAttributes() const {
|
||||||
return builder.getArrayAttr(boundMaps);
|
return builder.getArrayAttr(boundMaps);
|
||||||
|
@ -90,11 +91,11 @@ struct KrnlIterateOperandPack {
|
||||||
private:
|
private:
|
||||||
int _boundIdx = 0;
|
int _boundIdx = 0;
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value*, 8> _operands;
|
llvm::SmallVector<mlir::Value, 8> _operands;
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Attribute, 8> boundMaps;
|
llvm::SmallVector<mlir::Attribute, 8> boundMaps;
|
||||||
|
|
||||||
llvm::ArrayRef<mlir::Value*> inputLoops, optimizedLoops;
|
llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;
|
||||||
|
|
||||||
mlir::Builder& builder;
|
mlir::Builder& builder;
|
||||||
};
|
};
|
||||||
|
|
|
@ -44,21 +44,21 @@ static MemRefType convertTensorToMemRef(TensorType type) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Insert an allocation and deallocation for the given MemRefType.
|
/// Insert an allocation and deallocation for the given MemRefType.
|
||||||
static Value *insertAllocAndDealloc(MemRefType type, Location loc,
|
static Value insertAllocAndDealloc(MemRefType type, Location loc,
|
||||||
PatternRewriter &rewriter,
|
PatternRewriter &rewriter,
|
||||||
bool insertDealloc,
|
bool insertDealloc,
|
||||||
ArrayRef<Value *> operands = {}) {
|
ArrayRef<Value> operands = {}) {
|
||||||
// Put together alloc operands for any dynamic dimensions of the memref.
|
// Put together alloc operands for any dynamic dimensions of the memref.
|
||||||
AllocOp alloc;
|
AllocOp alloc;
|
||||||
if (!operands.empty()) {
|
if (!operands.empty()) {
|
||||||
auto memRefShape = type.getShape();
|
auto memRefShape = type.getShape();
|
||||||
auto rank = memRefShape.size();
|
auto rank = memRefShape.size();
|
||||||
|
|
||||||
std::map<int, Value *> fromOperands;
|
std::map<int, Value> fromOperands;
|
||||||
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
|
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
|
||||||
int memRefDimIdx = rank - 1 - reversedIdx;
|
int memRefDimIdx = rank - 1 - reversedIdx;
|
||||||
if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
|
if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
|
||||||
Value *maxDim = nullptr;
|
Value maxDim = nullptr;
|
||||||
for (int i = 0; i < operands.size(); i++) {
|
for (int i = 0; i < operands.size(); i++) {
|
||||||
auto operandShape =
|
auto operandShape =
|
||||||
operands[i]->getType().cast<MemRefType>().getShape();
|
operands[i]->getType().cast<MemRefType>().getShape();
|
||||||
|
@ -85,7 +85,7 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value *, 4> allocOperands;
|
SmallVector<Value, 4> allocOperands;
|
||||||
for (int i = 0; i < rank; ++i)
|
for (int i = 0; i < rank; ++i)
|
||||||
if (memRefShape[i] < 0)
|
if (memRefShape[i] < 0)
|
||||||
allocOperands.push_back(fromOperands[i]);
|
allocOperands.push_back(fromOperands[i]);
|
||||||
|
@ -146,14 +146,14 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
|
||||||
|
|
||||||
// Get run-time dimension information for unknown dimensions used for
|
// Get run-time dimension information for unknown dimensions used for
|
||||||
// broadcasting.
|
// broadcasting.
|
||||||
std::map<int, std::map<int, Value *>>
|
std::map<int, std::map<int, Value>>
|
||||||
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
||||||
MemRefType memRefType, ArrayRef<Value *> operands) {
|
MemRefType memRefType, ArrayRef<Value> operands) {
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
int64_t rank = memRefShape.size();
|
int64_t rank = memRefShape.size();
|
||||||
// For unknown dimensions, we need to get dimension values at runtime in
|
// For unknown dimensions, we need to get dimension values at runtime in
|
||||||
// order to do broadcasting.
|
// order to do broadcasting.
|
||||||
std::map<int, std::map<int, Value *>> DimInfo;
|
std::map<int, std::map<int, Value>> DimInfo;
|
||||||
// For each result dimension, compute the number of sharing operands.
|
// For each result dimension, compute the number of sharing operands.
|
||||||
// Sharing operands are operands sharing the same index (counting from the
|
// Sharing operands are operands sharing the same index (counting from the
|
||||||
// rightmost to the leftmost) for a given dimension.
|
// rightmost to the leftmost) for a given dimension.
|
||||||
|
@ -173,7 +173,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
||||||
// We only care about unknown dimensions whose number of sharing operands is
|
// We only care about unknown dimensions whose number of sharing operands is
|
||||||
// more than one, since they are potentially broadcasted dimensions.
|
// more than one, since they are potentially broadcasted dimensions.
|
||||||
for (int i = 0; i < operands.size(); ++i) {
|
for (int i = 0; i < operands.size(); ++i) {
|
||||||
std::map<int, Value *> broadcastedDims;
|
std::map<int, Value> broadcastedDims;
|
||||||
auto shape = operands[i]->getType().cast<MemRefType>().getShape();
|
auto shape = operands[i]->getType().cast<MemRefType>().getShape();
|
||||||
int size = shape.size();
|
int size = shape.size();
|
||||||
for (int j = 0; j < shape.size(); ++j) {
|
for (int j = 0; j < shape.size(); ++j) {
|
||||||
|
@ -192,17 +192,17 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
||||||
|
|
||||||
// Extract induction variables that are used for broadcasting values of a
|
// Extract induction variables that are used for broadcasting values of a
|
||||||
// given operand.
|
// given operand.
|
||||||
std::vector<Value *>
|
std::vector<Value>
|
||||||
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
|
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value *> loopIVs, Value *operand,
|
ArrayRef<Value> loopIVs, Value operand,
|
||||||
std::map<int, Value *> broadcastedDims) {
|
std::map<int, Value> broadcastedDims) {
|
||||||
// `operand` must have a ranked type. This should have been checked by the
|
// `operand` must have a ranked type. This should have been checked by the
|
||||||
// shape inference pass.
|
// shape inference pass.
|
||||||
auto operandShape = operand->getType().cast<MemRefType>().getShape();
|
auto operandShape = operand->getType().cast<MemRefType>().getShape();
|
||||||
auto rank = operandShape.size();
|
auto rank = operandShape.size();
|
||||||
auto loopCount = loopIVs.size();
|
auto loopCount = loopIVs.size();
|
||||||
|
|
||||||
std::vector<Value *> newLoopIVs;
|
std::vector<Value> newLoopIVs;
|
||||||
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
|
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
|
||||||
auto dimIdx = rank - 1 - reversedIdx;
|
auto dimIdx = rank - 1 - reversedIdx;
|
||||||
auto loopIdx = loopCount - 1 - reversedIdx;
|
auto loopIdx = loopCount - 1 - reversedIdx;
|
||||||
|
@ -247,7 +247,7 @@ struct ScalarOp<ONNXMulOp> {
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<ONNXDivOp> {
|
struct ScalarOp<ONNXDivOp> {
|
||||||
using FOp = DivFOp;
|
using FOp = DivFOp;
|
||||||
using IOp = DivISOp;
|
using IOp = SignedDivIOp;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -295,8 +295,8 @@ using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
|
||||||
// Scalar unary ops for lowering to Krnl dialect.
|
// Scalar unary ops for lowering to Krnl dialect.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <typename UnaryOp>
|
template <typename UnaryOp>
|
||||||
Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
|
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value *> operands,
|
ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
/* Lower UnaryOp to Ops in the Standard dialect.
|
/* Lower UnaryOp to Ops in the Standard dialect.
|
||||||
*/
|
*/
|
||||||
|
@ -318,14 +318,13 @@ Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXTanhOp
|
// Scalar unary ops for lowering ONNXTanhOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op,
|
Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Value> operands,
|
||||||
ArrayRef<Value *> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||||
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||||
|
@ -342,14 +341,13 @@ Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op,
|
||||||
// Scalar unary ops for lowering ONNXSinhOp
|
// Scalar unary ops for lowering ONNXSinhOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
|
Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Value> operands,
|
||||||
ArrayRef<Value *> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||||
// ConstantOp 2)
|
// ConstantOp 2)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
||||||
|
@ -366,14 +364,13 @@ Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
|
||||||
// Scalar unary ops for lowering ONNXCoshOp
|
// Scalar unary ops for lowering ONNXCoshOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
|
Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Value> operands,
|
||||||
ArrayRef<Value *> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||||
// ConstantOp 2)
|
// ConstantOp 2)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
||||||
|
@ -390,14 +387,14 @@ Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
|
||||||
// Scalar unary ops for lowering ONNXSigmoidOp
|
// Scalar unary ops for lowering ONNXSigmoidOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
|
Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value *> operands,
|
ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
|
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
|
||||||
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
|
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
||||||
|
@ -413,8 +410,8 @@ Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
|
||||||
// Scalar unary ops for lowering ONNXHardSigmoidOp
|
// Scalar unary ops for lowering ONNXHardSigmoidOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
||||||
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value *> operands,
|
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// %Y = AddFOp(MulFOp(alpha, %X), beta)
|
// %Y = AddFOp(MulFOp(alpha, %X), beta)
|
||||||
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
|
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
|
||||||
|
@ -424,7 +421,7 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
||||||
// %Z,
|
// %Z,
|
||||||
// Constant 1)
|
// Constant 1)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
|
||||||
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
|
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
|
||||||
|
|
||||||
|
@ -449,14 +446,14 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
||||||
// Scalar unary ops for lowering ONNXEluOp
|
// Scalar unary ops for lowering ONNXEluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value *> operands,
|
ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||||
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
|
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
|
@ -478,15 +475,14 @@ Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXReluOp
|
// Scalar unary ops for lowering ONNXReluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op,
|
Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Value> operands,
|
||||||
ArrayRef<Value *> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||||
// ConstantOp 0,
|
// ConstantOp 0,
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
auto lessThanZero =
|
auto lessThanZero =
|
||||||
|
@ -500,15 +496,15 @@ Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op,
|
||||||
// Scalar unary ops for lowering ONNXLeakyReluOp
|
// Scalar unary ops for lowering ONNXLeakyReluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *
|
Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
|
||||||
mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value *> operands,
|
ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||||
// MulFOp(alpha, %X),
|
// MulFOp(alpha, %X),
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
|
@ -525,9 +521,8 @@ mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXSeluOp
|
// Scalar unary ops for lowering ONNXSeluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
|
Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Value> operands,
|
||||||
ArrayRef<Value *> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
|
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
|
||||||
// MulFOp(gamma, %X),
|
// MulFOp(gamma, %X),
|
||||||
|
@ -535,7 +530,7 @@ Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
|
||||||
// SubFOp(MulFOp(alpha, ExpOp(%X)),
|
// SubFOp(MulFOp(alpha, ExpOp(%X)),
|
||||||
// alpha)))
|
// alpha)))
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
|
||||||
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
|
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
|
||||||
|
|
||||||
|
@ -558,13 +553,12 @@ Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
|
||||||
// Scalar unary ops for lowering ONNXReciprocalOp
|
// Scalar unary ops for lowering ONNXReciprocalOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *
|
Value mapToLowerScalarOp<ONNXReciprocalOp>(
|
||||||
mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types,
|
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
|
||||||
ArrayRef<Value *> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
|
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
|
||||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
||||||
auto result = rewriter.create<DivFOp>(loc, one, operand);
|
auto result = rewriter.create<DivFOp>(loc, one, operand);
|
||||||
|
@ -576,15 +570,15 @@ mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXMaxOp
|
// Scalar unary ops for lowering ONNXMaxOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
|
Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value *> operands,
|
ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
|
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
|
||||||
// %X,
|
// %X,
|
||||||
// %Y)
|
// %Y)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *lhs = operands[0];
|
Value lhs = operands[0];
|
||||||
Value *rhs = operands[1];
|
Value rhs = operands[1];
|
||||||
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
|
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
|
||||||
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
|
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
|
||||||
return result;
|
return result;
|
||||||
|
@ -594,15 +588,15 @@ Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXMinOp
|
// Scalar unary ops for lowering ONNXMinOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value *mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
|
Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value *> operands,
|
ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
|
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
|
||||||
// %X,
|
// %X,
|
||||||
// %Y)
|
// %Y)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value *lhs = operands[0];
|
Value lhs = operands[0];
|
||||||
Value *rhs = operands[1];
|
Value rhs = operands[1];
|
||||||
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
|
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
|
||||||
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
|
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
|
||||||
return result;
|
return result;
|
||||||
|
@ -615,7 +609,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
|
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
|
||||||
PatternMatchResult
|
PatternMatchResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
// TODO: Check that the types are valid.
|
// TODO: Check that the types are valid.
|
||||||
// An element-wise unary operation must have all operands and the result of
|
// An element-wise unary operation must have all operands and the result of
|
||||||
|
@ -632,7 +626,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
// dimensions with the result at this pre-optimization phase.
|
// dimensions with the result at this pre-optimization phase.
|
||||||
// TODO: verify that dimensions match.
|
// TODO: verify that dimensions match.
|
||||||
// TODO: can the dimension of the result differ after optimizations?
|
// TODO: can the dimension of the result differ after optimizations?
|
||||||
Value *alloc;
|
Value alloc;
|
||||||
bool insertDealloc = checkInsertDealloc(op);
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
|
@ -647,7 +641,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Define loops.
|
// Define loops.
|
||||||
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
||||||
std::vector<Value *> originalLoops;
|
std::vector<Value> originalLoops;
|
||||||
originalLoops.reserve(rank);
|
originalLoops.reserve(rank);
|
||||||
for (auto result : loopsOp.getResults()) {
|
for (auto result : loopsOp.getResults()) {
|
||||||
originalLoops.push_back(result);
|
originalLoops.push_back(result);
|
||||||
|
@ -655,7 +649,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Define loop optimization.
|
// Define loop optimization.
|
||||||
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
||||||
std::vector<Value *> optimizedLoops;
|
std::vector<Value> optimizedLoops;
|
||||||
optimizedLoops.reserve(rank);
|
optimizedLoops.reserve(rank);
|
||||||
for (auto result : optimizedLoopsOp.getResults()) {
|
for (auto result : optimizedLoopsOp.getResults()) {
|
||||||
optimizedLoops.push_back(result);
|
optimizedLoops.push_back(result);
|
||||||
|
@ -695,7 +689,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
rewriter.setInsertionPointToStart(&iterationBlock);
|
rewriter.setInsertionPointToStart(&iterationBlock);
|
||||||
|
|
||||||
// Handle the operation:
|
// Handle the operation:
|
||||||
SmallVector<Value *, 4> loopIVs;
|
SmallVector<Value, 4> loopIVs;
|
||||||
for (auto arg : iterationBlock.getArguments())
|
for (auto arg : iterationBlock.getArguments())
|
||||||
loopIVs.push_back(arg);
|
loopIVs.push_back(arg);
|
||||||
|
|
||||||
|
@ -718,7 +712,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
|
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
|
||||||
PatternMatchResult
|
PatternMatchResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
// TODO: Check that the types are valid.
|
// TODO: Check that the types are valid.
|
||||||
// An element-wise variadic operation must have all operands and the result
|
// An element-wise variadic operation must have all operands and the result
|
||||||
|
@ -730,7 +724,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
auto memRefType = convertTensorToMemRef(tensorType);
|
auto memRefType = convertTensorToMemRef(tensorType);
|
||||||
|
|
||||||
Value *alloc;
|
Value alloc;
|
||||||
bool insertDealloc = checkInsertDealloc(op);
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
// If the output has a dynamic dimension, we compute its dimension at
|
// If the output has a dynamic dimension, we compute its dimension at
|
||||||
// runtime by using dimensions from the operands.
|
// runtime by using dimensions from the operands.
|
||||||
|
@ -749,7 +743,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Define loops.
|
// Define loops.
|
||||||
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
||||||
std::vector<Value *> originalLoops;
|
std::vector<Value> originalLoops;
|
||||||
originalLoops.reserve(rank);
|
originalLoops.reserve(rank);
|
||||||
for (auto result : loopsOp.getResults()) {
|
for (auto result : loopsOp.getResults()) {
|
||||||
originalLoops.push_back(result);
|
originalLoops.push_back(result);
|
||||||
|
@ -757,7 +751,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Define loop optimization.
|
// Define loop optimization.
|
||||||
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
||||||
std::vector<Value *> optimizedLoops;
|
std::vector<Value> optimizedLoops;
|
||||||
optimizedLoops.reserve(rank);
|
optimizedLoops.reserve(rank);
|
||||||
for (auto result : optimizedLoopsOp.getResults()) {
|
for (auto result : optimizedLoopsOp.getResults()) {
|
||||||
optimizedLoops.push_back(result);
|
optimizedLoops.push_back(result);
|
||||||
|
@ -781,7 +775,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Get run-time dimension information for unknown dimensions used for
|
// Get run-time dimension information for unknown dimensions used for
|
||||||
// broadcasting.
|
// broadcasting.
|
||||||
std::map<int, std::map<int, Value *>> broadcastedDimInfo =
|
std::map<int, std::map<int, Value>> broadcastedDimInfo =
|
||||||
getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
|
getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
|
||||||
|
|
||||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||||
|
@ -801,12 +795,12 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
rewriter.setInsertionPointToStart(&iterationBlock);
|
rewriter.setInsertionPointToStart(&iterationBlock);
|
||||||
|
|
||||||
// Handle the operation:
|
// Handle the operation:
|
||||||
SmallVector<Value *, 4> loopIVs;
|
SmallVector<Value, 4> loopIVs;
|
||||||
for (auto arg : iterationBlock.getArguments())
|
for (auto arg : iterationBlock.getArguments())
|
||||||
loopIVs.push_back(arg);
|
loopIVs.push_back(arg);
|
||||||
|
|
||||||
// Fold over operands for each of their scalar values
|
// Fold over operands for each of their scalar values
|
||||||
Value *accumulated, *next;
|
Value accumulated, next;
|
||||||
auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
|
auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
|
||||||
loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
|
loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
|
||||||
accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
|
accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
|
||||||
|
@ -831,17 +825,17 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
auto memRefType = convertTensorToMemRef(tensorType);
|
auto memRefType = convertTensorToMemRef(tensorType);
|
||||||
Value *alloc;
|
Value alloc;
|
||||||
|
|
||||||
// Compute size in bytes.
|
// Compute size in bytes.
|
||||||
Value *tensorSize = rewriter.create<ConstantOp>(
|
Value tensorSize = rewriter.create<ConstantOp>(
|
||||||
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||||
getMemRefEltSizeInBytes(memRefType)));
|
getMemRefEltSizeInBytes(memRefType)));
|
||||||
bool insertDealloc = checkInsertDealloc(op);
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
@ -849,14 +843,14 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
} else {
|
} else {
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
SmallVector<Value *, 4> allocOperands;
|
SmallVector<Value, 4> allocOperands;
|
||||||
for (int i = 0; i < memRefShape.size(); ++i) {
|
for (int i = 0; i < memRefShape.size(); ++i) {
|
||||||
// The shape array can always be used to construct shape information of
|
// The shape array can always be used to construct shape information of
|
||||||
// the result.
|
// the result.
|
||||||
Value *index = rewriter.create<ConstantOp>(
|
Value index = rewriter.create<ConstantOp>(
|
||||||
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
||||||
Value *loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
||||||
Value *int64LoadedVal = rewriter.create<ZeroExtendIOp>(
|
Value int64LoadedVal = rewriter.create<ZeroExtendIOp>(
|
||||||
loc, loadedVal, rewriter.getIntegerType(64));
|
loc, loadedVal, rewriter.getIntegerType(64));
|
||||||
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
|
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
|
||||||
allocOperands.push_back(rewriter.create<IndexCastOp>(
|
allocOperands.push_back(rewriter.create<IndexCastOp>(
|
||||||
|
|
|
@ -30,7 +30,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
||||||
operandItr++;
|
operandItr++;
|
||||||
|
|
||||||
// Organize operands into lower/upper bounds in affine.for ready formats.
|
// Organize operands into lower/upper bounds in affine.for ready formats.
|
||||||
SmallVector<Value *, 4> lbOperands, ubOperands;
|
SmallVector<Value, 4> lbOperands, ubOperands;
|
||||||
AffineMap lbMap, ubMap;
|
AffineMap lbMap, ubMap;
|
||||||
for (int boundType = 0; boundType < 2; boundType++) {
|
for (int boundType = 0; boundType < 2; boundType++) {
|
||||||
auto &operands = boundType == 0 ? lbOperands : ubOperands;
|
auto &operands = boundType == 0 ? lbOperands : ubOperands;
|
||||||
|
|
|
@ -51,7 +51,7 @@ public:
|
||||||
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
|
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto *context = op->getContext();
|
auto *context = op->getContext();
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
@ -66,27 +66,27 @@ public:
|
||||||
// First operand.
|
// First operand.
|
||||||
Type dstType =
|
Type dstType =
|
||||||
operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
||||||
Value *alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||||
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
|
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
|
||||||
Value *alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
|
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
|
||||||
|
|
||||||
// Second operand.
|
// Second operand.
|
||||||
Type srcType =
|
Type srcType =
|
||||||
operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
||||||
Value *alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||||
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
|
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
|
||||||
Value *alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
|
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
|
||||||
|
|
||||||
// Size.
|
// Size.
|
||||||
Value *int64Size = rewriter.create<LLVM::SExtOp>(
|
Value int64Size = rewriter.create<LLVM::SExtOp>(
|
||||||
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
|
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
|
||||||
|
|
||||||
// Memcpy call
|
// Memcpy call
|
||||||
rewriter.create<CallOp>(
|
rewriter.create<CallOp>(
|
||||||
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||||
ArrayRef<Value *>(
|
ArrayRef<Value>(
|
||||||
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
|
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
|
||||||
|
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
@ -210,7 +210,7 @@ public:
|
||||||
|
|
||||||
// Retrieve dynamic mem refs from wrapped input, and convert every one of
|
// Retrieve dynamic mem refs from wrapped input, and convert every one of
|
||||||
// them to static mem refs.
|
// them to static mem refs.
|
||||||
SmallVector<Value *, 4> staticInputs;
|
SmallVector<Value, 4> staticInputs;
|
||||||
auto wrappedInput = entryPointEntryBlock.getArgument(0);
|
auto wrappedInput = entryPointEntryBlock.getArgument(0);
|
||||||
for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
|
for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
|
||||||
// Call API function to retrieve the i-th dynamic memref.
|
// Call API function to retrieve the i-th dynamic memref.
|
||||||
|
@ -225,13 +225,12 @@ public:
|
||||||
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
||||||
auto one = rewriter.create<LLVM::ConstantOp>(
|
auto one = rewriter.create<LLVM::ConstantOp>(
|
||||||
loc, int32Ty, rewriter.getI32IntegerAttr(1));
|
loc, int32Ty, rewriter.getI32IntegerAttr(1));
|
||||||
Value *ptrToMemRef =
|
Value ptrToMemRef = rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
|
||||||
rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
|
|
||||||
/*alignment=*/0);
|
/*alignment=*/0);
|
||||||
|
|
||||||
// Fill in the memref underlying ptrToMemRef with information extracted
|
// Fill in the memref underlying ptrToMemRef with information extracted
|
||||||
// from dynMemRef.
|
// from dynMemRef.
|
||||||
fillPtrToMemRefWithDynMemRef(*dynMemRef, *ptrToMemRef, rewriter, loc,
|
fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc,
|
||||||
apiRegistry, llvmDialect);
|
apiRegistry, llvmDialect);
|
||||||
|
|
||||||
// ptrToMemRef will be an input to main computation graph function.
|
// ptrToMemRef will be an input to main computation graph function.
|
||||||
|
@ -261,8 +260,8 @@ public:
|
||||||
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
||||||
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
||||||
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
||||||
fillDynMemRefWithMemRef(*outMemRef, *outDynMemRef, rewriter, loc,
|
fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry,
|
||||||
apiRegistry, llvmDialect);
|
llvmDialect);
|
||||||
auto zero = rewriter.create<LLVM::ConstantOp>(
|
auto zero = rewriter.create<LLVM::ConstantOp>(
|
||||||
loc, int32Ty, rewriter.getI32IntegerAttr(0));
|
loc, int32Ty, rewriter.getI32IntegerAttr(0));
|
||||||
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
||||||
|
@ -270,7 +269,7 @@ public:
|
||||||
|
|
||||||
// Return wrapped output.
|
// Return wrapped output.
|
||||||
rewriter.create<LLVM::ReturnOp>(loc,
|
rewriter.create<LLVM::ReturnOp>(loc,
|
||||||
SmallVector<Value *, 1>({wrappedOutput}));
|
SmallVector<Value, 1>({wrappedOutput}));
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -315,11 +314,11 @@ private:
|
||||||
|
|
||||||
// Call a registered API, return the return SSA values if only one result is
|
// Call a registered API, return the return SSA values if only one result is
|
||||||
// returned, otherwise return nullptr.
|
// returned, otherwise return nullptr.
|
||||||
Value *callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
|
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
|
||||||
API apiId, ArrayRef<Value *> params) const {
|
API apiId, ArrayRef<Value> params) const {
|
||||||
auto returnVals = rewriter.create<LLVM::CallOp>(
|
auto returnVals = rewriter.create<LLVM::CallOp>(
|
||||||
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
|
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
|
||||||
ArrayRef<Value *>(params));
|
ArrayRef<Value>(params));
|
||||||
if (returnVals.getNumResults() == 1)
|
if (returnVals.getNumResults() == 1)
|
||||||
return returnVals.getResult(0);
|
return returnVals.getResult(0);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -348,12 +347,11 @@ private:
|
||||||
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
||||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
|
||||||
Value *memRef =
|
Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, ptrToMemRef);
|
||||||
rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, &ptrToMemRef);
|
|
||||||
|
|
||||||
// Set dataPtr and alignedDataPtr;
|
// Set dataPtr and alignedDataPtr;
|
||||||
auto dataPtr =
|
auto dataPtr =
|
||||||
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {&dynMemRef});
|
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
|
||||||
dataPtr = rewriter.create<LLVM::BitcastOp>(
|
dataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, memRefTy.getStructElementType(0), dataPtr);
|
loc, memRefTy.getStructElementType(0), dataPtr);
|
||||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||||
|
@ -373,9 +371,9 @@ private:
|
||||||
// Get rank, sizes array ptr and strides array ptr.
|
// Get rank, sizes array ptr and strides array ptr.
|
||||||
auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
|
auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
|
||||||
auto sizesArrayPtr =
|
auto sizesArrayPtr =
|
||||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&dynMemRef});
|
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
|
||||||
auto stridesArrayPtr =
|
auto stridesArrayPtr =
|
||||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&dynMemRef});
|
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {dynMemRef});
|
||||||
|
|
||||||
for (decltype(rank) i = 0; i < rank; i++) {
|
for (decltype(rank) i = 0; i < rank; i++) {
|
||||||
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
@ -384,7 +382,7 @@ private:
|
||||||
// Insert size of the dimension.
|
// Insert size of the dimension.
|
||||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
||||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||||
ArrayRef<Value *>({dimIdx}));
|
ArrayRef<Value>({dimIdx}));
|
||||||
auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(),
|
auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(),
|
||||||
dimSizePtr);
|
dimSizePtr);
|
||||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||||
|
@ -395,7 +393,7 @@ private:
|
||||||
// Insert stride of the dimension.
|
// Insert stride of the dimension.
|
||||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
||||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||||
ArrayRef<Value *>({dimIdx}));
|
ArrayRef<Value>({dimIdx}));
|
||||||
auto dimStride = rewriter.create<LLVM::LoadOp>(
|
auto dimStride = rewriter.create<LLVM::LoadOp>(
|
||||||
loc, int64Ty.getPointerTo(), dimStridePtr);
|
loc, int64Ty.getPointerTo(), dimStridePtr);
|
||||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||||
|
@ -404,7 +402,7 @@ private:
|
||||||
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.create<LLVM::StoreOp>(loc, memRef, &ptrToMemRef);
|
rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
|
||||||
}
|
}
|
||||||
|
|
||||||
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
|
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
|
||||||
|
@ -415,19 +413,19 @@ private:
|
||||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
|
||||||
// Extract the data pointer, and record it in dynamic mem ref created.
|
// Extract the data pointer, and record it in dynamic mem ref created.
|
||||||
Value *outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
|
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
|
||||||
loc, outMemRefTy.getStructElementType(0), &outMemRef,
|
loc, outMemRefTy.getStructElementType(0), outMemRef,
|
||||||
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
|
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
|
||||||
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
|
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
|
||||||
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
||||||
{&outDynMemRef, outMemRefDataPtr});
|
{outDynMemRef, outMemRefDataPtr});
|
||||||
|
|
||||||
auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
|
auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
|
||||||
auto sizesArrayPtr =
|
auto sizesArrayPtr =
|
||||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&outDynMemRef});
|
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
|
||||||
auto stridesArrayPtr =
|
auto stridesArrayPtr =
|
||||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&outDynMemRef});
|
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outDynMemRef});
|
||||||
|
|
||||||
for (decltype(rank) i = 0; i < rank; i++) {
|
for (decltype(rank) i = 0; i < rank; i++) {
|
||||||
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
@ -435,22 +433,22 @@ private:
|
||||||
|
|
||||||
// Transfer size of dimension from memref to dynamic memref.
|
// Transfer size of dimension from memref to dynamic memref.
|
||||||
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
|
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
|
||||||
loc, int64Ty, &outMemRef,
|
loc, int64Ty, outMemRef,
|
||||||
rewriter.getArrayAttr(
|
rewriter.getArrayAttr(
|
||||||
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
|
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
|
||||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
||||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||||
ArrayRef<Value *>({dimIdx}));
|
ArrayRef<Value>({dimIdx}));
|
||||||
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
|
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
|
||||||
|
|
||||||
// Transfer stride of dimension from memref to dynamic memref.
|
// Transfer stride of dimension from memref to dynamic memref.
|
||||||
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
|
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
|
||||||
loc, int64Ty, &outMemRef,
|
loc, int64Ty, outMemRef,
|
||||||
rewriter.getArrayAttr(
|
rewriter.getArrayAttr(
|
||||||
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
||||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
||||||
loc, int64Ty.getPointerTo(), stridesArrayPtr,
|
loc, int64Ty.getPointerTo(), stridesArrayPtr,
|
||||||
ArrayRef<Value *>({dimIdx}));
|
ArrayRef<Value>({dimIdx}));
|
||||||
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
|
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue