[NFC] Change structure of conversion folder. (#96)

* Change structure of conversion folder.
* Fix comments.

Co-authored-by: Tian Jin <tjingrant@gmail.com>
parent: 32f08bcf0c
commit: ee3e140ddb
@@ -62,7 +62,21 @@ target_include_directories(onnf_shape_inference
 target_link_libraries(onnf_shape_inference ${MLIRLibs})
 add_dependencies(onnf_shape_inference gen_krnl_ops)
 
-add_library(onnf_lower_frontend conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
+add_library(onnf_lower_frontend
+        conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
+        conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
+        conversion/onnx_to_krnl/math/elementwise.cpp
+        conversion/onnx_to_krnl/math/gemm.cpp
+        conversion/onnx_to_krnl/math/matmul.cpp
+        conversion/onnx_to_krnl/math/reduction.cpp
+        conversion/onnx_to_krnl/math/softmax.cpp
+        conversion/onnx_to_krnl/nn/conv.cpp
+        conversion/onnx_to_krnl/nn/normalization.cpp
+        conversion/onnx_to_krnl/tensor/identity.cpp
+        conversion/onnx_to_krnl/tensor/reshape.cpp
+        conversion/onnx_to_krnl/tensor/transpose.cpp
+        conversion/onnx_to_krnl/tensor/unsqueeze.cpp
+        conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
 target_include_directories(onnf_lower_frontend
         PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
         ${ONNF_SRC_ROOT})
@@ -8,404 +8,11 @@
 // Krnl IR and standard operations.
 //
 //===----------------------------------------------------------------------===//
 
-#include <map>
-
-#include "mlir/Dialect/AffineOps/AffineOps.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/Sequence.h"
-
-#include "src/dialect/krnl/krnl_helper.hpp"
-#include "src/dialect/krnl/krnl_ops.hpp"
-#include "src/dialect/onnx/onnx_ops.hpp"
-#include "src/pass/passes.hpp"
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 
 using namespace mlir;
 
-//===----------------------------------------------------------------------===//
-// FrontendToAffine RewritePatterns
-//===----------------------------------------------------------------------===//
-
-/// Check if all dimensions are known at compile time.
-static bool hasAllConstantDimensions(MemRefType type) {
-  auto memRefShape = type.getShape();
-  for (int i = 0; i < memRefShape.size(); ++i)
-    if (memRefShape[i] < 0)
-      return false;
-  return true;
-}
-
-/// Get the corresponding MemRefType of a given TensorType/MemRefType.
-static MemRefType convertToMemRefType(Type type) {
-  MemRefType memRefType;
-  auto tensorType = type.dyn_cast<TensorType>();
-  if (tensorType) {
-    assert(tensorType.hasRank() && "expected only ranked shapes");
-    memRefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-  } else {
-    memRefType = type.dyn_cast<MemRefType>();
-  }
-  return memRefType;
-}
-
-/// Insert an allocation and deallocation for the given MemRefType.
-static Value insertAllocAndDealloc(MemRefType type, Location loc,
-                                   PatternRewriter &rewriter,
-                                   bool insertDealloc,
-                                   ArrayRef<Value> operands = {}) {
-  // Put together alloc operands for any dynamic dimensions of the memref.
-  AllocOp alloc;
-  if (!operands.empty()) {
-    auto memRefShape = type.getShape();
-    auto rank = memRefShape.size();
-
-    std::map<int, Value> fromOperands;
-    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
-      int memRefDimIdx = rank - 1 - reversedIdx;
-      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
-        Value maxDim = nullptr;
-        for (int i = 0; i < operands.size(); i++) {
-          auto operandShape =
-              operands[i].getType().cast<MemRefType>().getShape();
-          int operandDimIdx = operandShape.size() - 1 - reversedIdx;
-
-          if (operandDimIdx < 0)
-            continue;
-
-          // In case of operations with broadcasting, the dimension of the
-          // alloc result is the maximum size along each dimension of the
-          // operands.
-          auto operandDim =
-              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
-          if (maxDim) {
-            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
-                                                        operandDim, maxDim);
-            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
-                                               maxDim);
-          } else {
-            maxDim = operandDim;
-          }
-        }
-        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
-      }
-    }
-
-    SmallVector<Value, 4> allocOperands;
-    for (int i = 0; i < rank; ++i)
-      if (memRefShape[i] < 0)
-        allocOperands.push_back(fromOperands[i]);
-    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
-  } else {
-    alloc = rewriter.create<AllocOp>(loc, type);
-  }
-
-  // Make sure to allocate at the beginning of the block if
-  // all dimensions are known.
-  auto *parentBlock = alloc.getOperation()->getBlock();
-  if (hasAllConstantDimensions(type))
-    alloc.getOperation()->moveBefore(&parentBlock->front());
-
-  if (insertDealloc) {
-    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
-    dealloc.getOperation()->moveBefore(&parentBlock->back());
-  }
-
-  return alloc;
-}
-
-// Determine if current function returns the result value of the
-// current op being lowered. If it does then dealloc should not be
-// inserted.
-static bool checkInsertDealloc(Operation *currentOp) {
-  auto parentBlock = currentOp->getBlock();
-
-  bool insertDealloc = true;
-  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
-    assert(currentOp->getNumResults() < 2 &&
-           "No more than one result supported (for now).");
-    // If there is at least one result to investigate.
-    if (currentOp->getNumResults() > 0) {
-      auto result = currentOp->getResult(0);
-      for (const auto &operand : op.getOperands())
-        if (operand == result)
-          insertDealloc = false;
-    }
-  });
-
-  return insertDealloc;
-}
-
-// Create a mapping from result type's dimensions to input type's dimensions,
-// given that the result type is the result of a reduction op over the input
-// type.
-std::map<int64_t, int64_t>
-getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
-  std::map<int64_t, int64_t> OutInDimMap;
-  int64_t rank = inputTy.getRank();
-
-  // Mark reduction axes.
-  std::vector<bool> isReductionAxis;
-  for (decltype(rank) i = 0; i < rank; ++i) {
-    if (std::find(axes.begin(), axes.end(), i) != axes.end())
-      isReductionAxis.push_back(true);
-    else
-      isReductionAxis.push_back(false);
-  }
-
-  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
-    // If it is a reduction axis, there is no relationship among dimensions.
-    if (isReductionAxis[inIndex]) {
-      if (keepdims)
-        outIndex++;
-    } else {
-      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
-      outIndex++;
-    }
-  }
-
-  return OutInDimMap;
-}
-
-// Add bounds associated with the op operand to the KRNL iteration pack.
-// Dynamic dimensions are supported.
-static void addDimensionToPack(ConversionPatternRewriter &rewriter,
-                               Location loc, KrnlIterateOperandPack &pack,
-                               Value operand, int index) {
-  auto shape = operand.getType().cast<MemRefType>().getShape();
-  if (shape[index] < 0) {
-    pack.pushConstantBound(0);
-    pack.pushOperandBound(
-        rewriter.create<DimOp>(loc, operand, index).getResult());
-  } else {
-    pack.pushConstantBound(0);
-    pack.pushConstantBound(shape[index]);
-  }
-}
-
-// Function that defines the KRNL dialect loops and their respective
-// optimized version.
-static KrnlOptimizeLoopsOp
-emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
-                   std::vector<Value> &loops,
-                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
-  // Define loops.
-  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
-  loops.reserve(numLoops);
-  for (auto result : loopsOp.getResults())
-    loops.push_back(result);
-
-  // Define optimized version of the loops.
-  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
-  optimizedLoops.reserve(numLoops);
-  for (auto result : optimizedLoopsOp.getResults())
-    optimizedLoops.push_back(result);
-
-  return optimizedLoopsOp;
-}
-
-// Function that emits the loops and their optimized version.
-// The function returns a reference to the inner optimization block.
-static Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
-                          std::vector<Value> &loops,
-                          std::vector<Value> &optimizedLoops,
-                          int64_t numLoops) {
-  KrnlOptimizeLoopsOp optimizedLoopsOp =
-      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
-  return &optimizedLoopsOp.region().front();
-}
-
-// Function which emits a basic set of loops and optimized loops
-// for a given operation argument. A reference to the loop optimization
-// block is returned in the last argument of the function.
-static void emitKrnlLoopsAndIterationForOperand(
-    ConversionPatternRewriter &rewriter, Location loc, Value operand,
-    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
-    KrnlIterateOp &iterateOp) {
-  // Operand shape.
-  auto shape = operand.getType().cast<MemRefType>().getShape();
-
-  // Number of loops.
-  int64_t rank = shape.size();
-
-  // Define loops and optimized loops.
-  std::vector<Value> optimizedLoops;
-  optimizedLoopsOp =
-      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
-
-  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
-  // Iterate over the loop nest.
-  for (int i = 0; i < rank; ++i)
-    addDimensionToPack(rewriter, loc, pack, operand, i);
-
-  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
-}
-
-unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
-  auto elementType = memRefType.getElementType();
-
-  unsigned sizeInBits;
-  if (elementType.isIntOrFloat()) {
-    sizeInBits = elementType.getIntOrFloatBitWidth();
-  } else {
-    auto vectorType = elementType.cast<VectorType>();
-    sizeInBits =
-        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
-  }
-  return llvm::divideCeil(sizeInBits, 8);
-}
-
-// Get run-time dimension information for unknown dimensions used for
-// broadcasting.
-std::map<int, std::map<int, Value>>
-getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
-                      MemRefType memRefType, ArrayRef<Value> operands) {
-  auto memRefShape = memRefType.getShape();
-  int64_t rank = memRefShape.size();
-  // For unknown dimensions, we need to get dimension values at runtime in
-  // order to do broadcasting.
-  std::map<int, std::map<int, Value>> DimInfo;
-  // For each result dimension, compute the number of sharing operands.
-  // Sharing operands are operands sharing the same index (counting from the
-  // rightmost to the leftmost) for a given dimension.
-  std::map<int, int> sharedDimCount;
-  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
-    int dimIdx = rank - 1 - reversedIdx;
-    sharedDimCount[dimIdx] = 0;
-    for (int i = 0; i < operands.size(); ++i) {
-      auto shape = operands[i].getType().cast<MemRefType>().getShape();
-      if (reversedIdx <= shape.size() - 1)
-        sharedDimCount[dimIdx]++;
-    }
-  }
-  // An unknown dimension can have a value of 1 or N (N > 1).
-  // If its value is 1, it is a broadcasted dimension.
-  // Otherwise, it is a non-broadcasted dimension.
-  // We only care about unknown dimensions whose number of sharing operands is
-  // more than one, since they are potentially broadcasted dimensions.
-  for (int i = 0; i < operands.size(); ++i) {
-    std::map<int, Value> broadcastedDims;
-    auto shape = operands[i].getType().cast<MemRefType>().getShape();
-    int size = shape.size();
-    for (int j = 0; j < shape.size(); ++j) {
-      if (shape[j] < 0 && sharedDimCount[rank - size + j] > 1) {
-        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
-        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
-        auto isBroadcasted =
-            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
-        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
-      }
-    }
-    DimInfo.insert(std::make_pair(i, broadcastedDims));
-  }
-  return DimInfo;
-}
-
-// Extract induction variables that are used for broadcasting values of a
-// given operand.
-std::vector<Value>
-getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
-                          ArrayRef<Value> loopIVs, Value operand,
-                          std::map<int, Value> broadcastedDims) {
-  // `operand` must have a ranked type. This should have been checked by the
-  // shape inference pass.
-  auto operandShape = operand.getType().cast<MemRefType>().getShape();
-  auto rank = operandShape.size();
-  auto loopCount = loopIVs.size();
-
-  std::vector<Value> newLoopIVs;
-  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
-    auto dimIdx = rank - 1 - reversedIdx;
-    auto loopIdx = loopCount - 1 - reversedIdx;
-    if (operandShape[dimIdx] == 1) {
-      // Broadcasted dimension
-      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
-      newLoopIVs.insert(newLoopIVs.begin(), zero);
-    } else if ((operandShape[dimIdx] == -1) &&
-               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
-      // Unknown dimension, it can have a value of 1 or N (N > 1).
-      // If its value is 1, it is a broadcasted dimension.
-      // Otherwise, it is a non-broadcasted dimension.
-      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
-      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
-                                           loopIVs[loopIdx]);
-      newLoopIVs.insert(newLoopIVs.begin(), idx);
-    } else {
-      // Non-broadcasted dimension
-      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
-    }
-  }
-  return newLoopIVs;
-}
-
-namespace {
-
-// This is to get a scalar operation of a given type for a specific operation.
-template <typename Op>
-struct ScalarOp {
-  using FOp = void;
-  using IOp = void;
-};
-
-template <typename FOp>
-using ScalarFOp = typename ScalarOp<FOp>::FOp;
-template <typename IOp>
-using ScalarIOp = typename ScalarOp<IOp>::IOp;
-
-// Get the identity element of an operation.
-// Return NULL if the function does not have an identity.
-template <typename DataType, typename Op>
-DataType getIdentityValue() {
-  return NULL;
-}
-
-//===----------------------------------------------------------------------===//
-// This is used in the innermost loop of a KrnlIterateOp to insert computation
-// composed of one or many scalar ops.
-// Use template specialization for each of the different ONNX operations.
-//===----------------------------------------------------------------------===//
-template <typename Op>
-Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
-                         ArrayRef<Value> operands,
-                         ConversionPatternRewriter &rewriter) {
-  auto loc = op->getLoc();
-  Type element_type = operands.front().getType();
-  if (element_type.isa<IntegerType>()) {
-    return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
-                                          mlir::None);
-  } else if (element_type.isa<FloatType>()) {
-    return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
-                                          mlir::None);
-  } else {
-    emitError(loc, "unsupported element type");
-    return nullptr;
-  }
-}
-
-// We divide the operator lowering into different categories.
-// These categories are mostly similar to the operator categories in ONNX:
-// https://github.com/onnx/onnx/tree/master/onnx/defs.
-// Besides, it is better to put operators with the same computation pattern
-// into the same category, e.g. element-wise operators will belong to the
-// elementwise category.
-
-// Math
-#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc"
-#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc"
-#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc"
-#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc"
-#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc"
-// Tensor
-#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc"
-#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc"
-#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc"
-#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
-// Neural network
-#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
-#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc"
-
 //===----------------------------------------------------------------------===//
 // EntryPoint Op lowering to Krnl Entry Point.
 //===----------------------------------------------------------------------===//
@@ -427,39 +34,6 @@ public:
   }
 };
 
-//===----------------------------------------------------------------------===//
-// Conversion from Tensor type to the Standard dialect MemRef type.
-//===----------------------------------------------------------------------===//
-
-struct TensorTypeConverter : public TypeConverter {
-  using TypeConverter::TypeConverter;
-
-  TensorTypeConverter() {
-    addConversion(convertType);
-  }
-
-  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
-    if (auto type = convertToMemRefType(t)) {
-      results.push_back(type);
-      return success();
-    }
-
-    results.push_back(t);
-    return success();
-  }
-
-  /// Return true if the inputs and outputs of the given function type are
-  /// legal. [Taken from MLIR and adapted to only check the legality of the
-  /// inputs. Once unranked results can be handled gracefully this
-  /// override needs to be removed in favour of the original MLIR one.]
-  bool isSignatureLegal(FunctionType funcType) {
-    return llvm::all_of(funcType.getInputs(),
-                        [this](Type type) { return isLegal(type); });
-  }
-};
-
-} // end anonymous namespace.
-
 //===----------------------------------------------------------------------===//
 // Frontend to Krnl Dialect lowering pass
 //===----------------------------------------------------------------------===//
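With the helpers gone, convert_onnx_to_krnl.cpp retains only the entry-point lowering and the pass driver; each operator pattern now lives in its own translation unit and is registered through the populate* functions declared in onnx_to_krnl_common.hpp (shown near the end of this diff). The pass body itself falls outside these hunks, so the following is only a sketch of the expected wiring, assuming the usual MLIR dialect-conversion skeleton of this era; the pass name and legality setup are placeholders, while the populate* calls match the declarations in the new header:

// Sketch only: the real pass body is not part of this diff. The pass and
// dialect names here are assumptions; the populate* signatures are the ones
// declared in onnx_to_krnl_common.hpp.
void FrontendToKrnlLoweringPass::runOnModule() {
  ConversionTarget target(getContext());
  target.addLegalDialect<KrnlOpsDialect, AffineOpsDialect, StandardOpsDialect>();

  OwningRewritePatternList patterns;
  // `math` directory patterns.
  populateLoweringONNXElementwiseOpPattern(patterns, &getContext());
  populateLoweringONNXGemmOpPattern(patterns, &getContext());
  populateLoweringONNXMatMulOpPattern(patterns, &getContext());
  populateLoweringONNXReductionOpPattern(patterns, &getContext());
  populateLoweringONNXSoftmaxOpPattern(patterns, &getContext());
  // `nn` directory patterns.
  populateLoweringONNXConvOpPattern(patterns, &getContext());
  populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
  // `tensor` directory patterns.
  populateLoweringONNXIdentityOpPattern(patterns, &getContext());
  populateLoweringONNXReshapeOpPattern(patterns, &getContext());
  populateLoweringONNXTransposeOpPattern(patterns, &getContext());
  populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());

  if (failed(applyPartialConversion(getModule(), target, patterns)))
    signalPassFailure();
}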
@@ -1,4 +1,4 @@
-//===----- elementwise.inc - Elementwise Ops ------------------------------===//
+//===----- elementwise.cpp - Elementwise Ops ------------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 template <>
 struct ScalarOp<ONNXAddOp> {
   using FOp = AddFOp;
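The rename keeps the ScalarOp specialization idiom intact: each elementwise file names the standard-dialect float and integer counterparts of an ONNX op, and the generic mapToLowerScalarOp, now in the common header, selects between them by element type. The hunk above is cut off after the FOp line, so the IOp line in this sketch is an assumption by analogy rather than something visible in the diff:

template <>
struct ScalarOp<ONNXAddOp> {
  using FOp = AddFOp; // visible in the hunk above
  using IOp = AddIOp; // assumed integer counterpart; truncated in this diff
};

// With the specialization in place, the dispatcher from the common header
// resolves at compile time:
//   mapToLowerScalarOp<ONNXAddOp>(op, result_types, operands, rewriter)
// emits an AddIOp for integer element types and an AddFOp for float ones.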
@@ -1,4 +1,4 @@
-//===----- gemm.inc - Lowering Gemm Op ------------------------------------===//
+//===----- gemm.cpp - Lowering Gemm Op ------------------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 template <typename GemmOp>
 struct ONNXGemmOpLowering : public ConversionPattern {
   ONNXGemmOpLowering(MLIRContext *ctx)
@@ -1,4 +1,4 @@
-//===----- matmul.inc - Lowering Matmul Op --------------------------------===//
+//===----- matmul.cpp - Lowering Matmul Op --------------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXMatMulOpLowering : public ConversionPattern {
   ONNXMatMulOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
-//===----- reduction.inc - Lowering Reduction Ops -------------------------===//
+//===----- reduction.cpp - Lowering Reduction Ops -------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 // Identity values
 template <>
 float getIdentityValue<float, ONNXReduceMaxOp>(){
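reduction.cpp likewise gains the common include and keeps the getIdentityValue specializations that seed each reduction's accumulator. The definition above is truncated mid-signature by the diff; a plausible completion, assuming a max-reduction is seeded with negative infinity (the actual body is not shown here):

#include <limits>

// Assumed body; only the signature appears in the hunk above.
template <>
float getIdentityValue<float, ONNXReduceMaxOp>() {
  // A max-reduction must start from the smallest representable value.
  return -std::numeric_limits<float>::infinity();
}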
@@ -1,4 +1,4 @@
-//===----- softmax.inc - Softmax Op ---------------------------------------===//
+//===----- softmax.cpp - Softmax Op ---------------------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXSoftmaxOpLowering : public ConversionPattern {
   ONNXSoftmaxOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
-//===----- conv.inc - Lowering Convolution Op -----------------------------===//
+//===----- conv.cpp - Lowering Convolution Op -----------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXConvNoBiasOpLowering : public ConversionPattern {
   ONNXConvNoBiasOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
-//===----- normalization.inc - Lowering Normalization Ops -----------------===//
+//===----- normalization.cpp - Lowering Normalization Ops -----------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
   ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
       : ConversionPattern(
@@ -0,0 +1,324 @@
+//====-- onnx_to_krnl_common.cpp - ONNX dialects to Krnl lowering ---------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains common code shared by the functions performing the
+// lowering to the KRNL dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+/// Check if all dimensions are known at compile time.
+bool hasAllConstantDimensions(MemRefType type) {
+  auto memRefShape = type.getShape();
+  for (int i = 0; i < memRefShape.size(); ++i)
+    if (memRefShape[i] < 0)
+      return false;
+  return true;
+}
+
+/// Get the corresponding MemRefType of a given TensorType/MemRefType.
+MemRefType convertToMemRefType(Type type) {
+  MemRefType memRefType;
+  auto tensorType = type.dyn_cast<TensorType>();
+  if (tensorType) {
+    assert(tensorType.hasRank() && "expected only ranked shapes");
+    memRefType =
+        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+  } else {
+    memRefType = type.dyn_cast<MemRefType>();
+  }
+  return memRefType;
+}
+
+/// Insert an allocation and deallocation for the given MemRefType.
+Value insertAllocAndDealloc(MemRefType type, Location loc,
+                            PatternRewriter &rewriter,
+                            bool insertDealloc,
+                            ArrayRef<Value> operands) {
+  // Put together alloc operands for any dynamic dimensions of the memref.
+  AllocOp alloc;
+  if (!operands.empty()) {
+    auto memRefShape = type.getShape();
+    auto rank = memRefShape.size();
+
+    std::map<int, Value> fromOperands;
+    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
+      int memRefDimIdx = rank - 1 - reversedIdx;
+      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
+        Value maxDim = nullptr;
+        for (int i = 0; i < operands.size(); i++) {
+          auto operandShape =
+              operands[i].getType().cast<MemRefType>().getShape();
+          int operandDimIdx = operandShape.size() - 1 - reversedIdx;
+
+          if (operandDimIdx < 0)
+            continue;
+
+          // In case of operations with broadcasting, the dimension of the
+          // alloc result is the maximum size along each dimension of the
+          // operands.
+          auto operandDim =
+              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
+          if (maxDim) {
+            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
+                                                        operandDim, maxDim);
+            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
+                                               maxDim);
+          } else {
+            maxDim = operandDim;
+          }
+        }
+        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
+      }
+    }
+
+    SmallVector<Value, 4> allocOperands;
+    for (int i = 0; i < rank; ++i)
+      if (memRefShape[i] < 0)
+        allocOperands.push_back(fromOperands[i]);
+    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
+  } else {
+    alloc = rewriter.create<AllocOp>(loc, type);
+  }
+
+  // Make sure to allocate at the beginning of the block if
+  // all dimensions are known.
+  auto *parentBlock = alloc.getOperation()->getBlock();
+  if (hasAllConstantDimensions(type))
+    alloc.getOperation()->moveBefore(&parentBlock->front());
+
+  if (insertDealloc) {
+    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+    dealloc.getOperation()->moveBefore(&parentBlock->back());
+  }
+
+  return alloc;
+}
+
+// Determine if current function returns the result value of the
+// current op being lowered. If it does then dealloc should not be
+// inserted.
+bool checkInsertDealloc(Operation *currentOp) {
+  auto parentBlock = currentOp->getBlock();
+
+  bool insertDealloc = true;
+  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
+    assert(currentOp->getNumResults() < 2 &&
+           "No more than one result supported (for now).");
+    // If there is at least one result to investigate.
+    if (currentOp->getNumResults() > 0) {
+      auto result = currentOp->getResult(0);
+      for (const auto &operand : op.getOperands())
+        if (operand == result)
+          insertDealloc = false;
+    }
+  });
+
+  return insertDealloc;
+}
+
+// Create a mapping from result type's dimensions to input type's dimensions,
+// given that the result type is the result of a reduction op over the input
+// type.
+std::map<int64_t, int64_t>
+getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
+  std::map<int64_t, int64_t> OutInDimMap;
+  int64_t rank = inputTy.getRank();
+
+  // Mark reduction axes.
+  std::vector<bool> isReductionAxis;
+  for (decltype(rank) i = 0; i < rank; ++i) {
+    if (std::find(axes.begin(), axes.end(), i) != axes.end())
+      isReductionAxis.push_back(true);
+    else
+      isReductionAxis.push_back(false);
+  }
+
+  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
+    // If it is a reduction axis, there is no relationship among dimensions.
+    if (isReductionAxis[inIndex]) {
+      if (keepdims)
+        outIndex++;
+    } else {
+      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
+      outIndex++;
+    }
+  }
+
+  return OutInDimMap;
+}
+
+// Add bounds associated with the op operand to the KRNL iteration pack.
+// Dynamic dimensions are supported.
+void addDimensionToPack(ConversionPatternRewriter &rewriter,
+                        Location loc, KrnlIterateOperandPack &pack,
+                        Value operand, int index) {
+  auto shape = operand.getType().cast<MemRefType>().getShape();
+  if (shape[index] < 0) {
+    pack.pushConstantBound(0);
+    pack.pushOperandBound(
+        rewriter.create<DimOp>(loc, operand, index).getResult());
+  } else {
+    pack.pushConstantBound(0);
+    pack.pushConstantBound(shape[index]);
+  }
+}
+
+// Function that defines the KRNL dialect loops and their respective
+// optimized version.
+KrnlOptimizeLoopsOp
+emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
+                   std::vector<Value> &loops,
+                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
+  // Define loops.
+  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
+  loops.reserve(numLoops);
+  for (auto result : loopsOp.getResults())
+    loops.push_back(result);
+
+  // Define optimized version of the loops.
+  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
+  optimizedLoops.reserve(numLoops);
+  for (auto result : optimizedLoopsOp.getResults())
+    optimizedLoops.push_back(result);
+
+  return optimizedLoopsOp;
+}
+
+// Function that emits the loops and their optimized version.
+// The function returns a reference to the inner optimization block.
+Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
+                   std::vector<Value> &loops,
+                   std::vector<Value> &optimizedLoops,
+                   int64_t numLoops) {
+  KrnlOptimizeLoopsOp optimizedLoopsOp =
+      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
+  return &optimizedLoopsOp.region().front();
+}
+
+// Function which emits a basic set of loops and optimized loops
+// for a given operation argument. A reference to the loop optimization
+// block is returned in the last argument of the function.
+void emitKrnlLoopsAndIterationForOperand(
+    ConversionPatternRewriter &rewriter, Location loc, Value operand,
+    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
+    KrnlIterateOp &iterateOp) {
+  // Operand shape.
+  auto shape = operand.getType().cast<MemRefType>().getShape();
+
+  // Number of loops.
+  int64_t rank = shape.size();
+
+  // Define loops and optimized loops.
+  std::vector<Value> optimizedLoops;
+  optimizedLoopsOp =
+      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
+
+  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
+  // Iterate over the loop nest.
+  for (int i = 0; i < rank; ++i)
+    addDimensionToPack(rewriter, loc, pack, operand, i);
+
+  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
+}
+
+unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+  auto elementType = memRefType.getElementType();
+
+  unsigned sizeInBits;
+  if (elementType.isIntOrFloat()) {
+    sizeInBits = elementType.getIntOrFloatBitWidth();
+  } else {
+    auto vectorType = elementType.cast<VectorType>();
+    sizeInBits =
+        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+  }
+  return llvm::divideCeil(sizeInBits, 8);
+}
+
+// Get run-time dimension information for unknown dimensions used for
+// broadcasting.
+std::map<int, std::map<int, Value>>
+getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
+                      MemRefType memRefType, ArrayRef<Value> operands) {
+  auto memRefShape = memRefType.getShape();
+  int64_t rank = memRefShape.size();
+  // For unknown dimensions, we need to get dimension values at runtime in
+  // order to do broadcasting.
+  std::map<int, std::map<int, Value>> DimInfo;
+  // For each result dimension, compute the number of sharing operands.
+  // Sharing operands are operands sharing the same index (counting from the
+  // rightmost to the leftmost) for a given dimension.
+  std::map<int, int> sharedDimCount;
+  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
+    int dimIdx = rank - 1 - reversedIdx;
+    sharedDimCount[dimIdx] = 0;
+    for (int i = 0; i < operands.size(); ++i) {
+      auto shape = operands[i].getType().cast<MemRefType>().getShape();
+      if (reversedIdx <= shape.size() - 1)
+        sharedDimCount[dimIdx]++;
+    }
+  }
+  // An unknown dimension can have a value of 1 or N (N > 1).
+  // If its value is 1, it is a broadcasted dimension.
+  // Otherwise, it is a non-broadcasted dimension.
+  // We only care about unknown dimensions whose number of sharing operands is
+  // more than one, since they are potentially broadcasted dimensions.
+  for (int i = 0; i < operands.size(); ++i) {
+    std::map<int, Value> broadcastedDims;
+    auto shape = operands[i].getType().cast<MemRefType>().getShape();
+    int size = shape.size();
+    for (int j = 0; j < shape.size(); ++j) {
+      if (shape[j] < 0 && sharedDimCount[rank - size + j] > 1) {
+        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
+        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+        auto isBroadcasted =
+            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
+        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
+      }
+    }
+    DimInfo.insert(std::make_pair(i, broadcastedDims));
+  }
+  return DimInfo;
+}
+
+// Extract induction variables that are used for broadcasting values of a
+// given operand.
+std::vector<Value>
+getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
+                          ArrayRef<Value> loopIVs, Value operand,
+                          std::map<int, Value> broadcastedDims) {
+  // `operand` must have a ranked type. This should have been checked by the
+  // shape inference pass.
+  auto operandShape = operand.getType().cast<MemRefType>().getShape();
+  auto rank = operandShape.size();
+  auto loopCount = loopIVs.size();
+
+  std::vector<Value> newLoopIVs;
+  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
+    auto dimIdx = rank - 1 - reversedIdx;
+    auto loopIdx = loopCount - 1 - reversedIdx;
+    if (operandShape[dimIdx] == 1) {
+      // Broadcasted dimension
+      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+      newLoopIVs.insert(newLoopIVs.begin(), zero);
+    } else if ((operandShape[dimIdx] == -1) &&
+               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
+      // Unknown dimension, it can have a value of 1 or N (N > 1).
+      // If its value is 1, it is a broadcasted dimension.
+      // Otherwise, it is a non-broadcasted dimension.
+      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
+                                           loopIVs[loopIdx]);
+      newLoopIVs.insert(newLoopIVs.begin(), idx);
+    } else {
+      // Non-broadcasted dimension
+      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
+    }
+  }
+  return newLoopIVs;
+}
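Of the helpers moved into this file, getReductionMapping is the easiest to misread, so its behavior is worth working out on a concrete case, derived directly from the loop above:

// Worked example for getReductionMapping (plain integers, no MLIR types):
//   input rank = 3, axes = {1}
//
//   keepdims = false -> output rank 2, mapping {0 -> 0, 1 -> 2}:
//     input dim 0 survives as output dim 0, input dim 1 is reduced away,
//     and input dim 2 survives as output dim 1.
//
//   keepdims = true  -> output rank 3, mapping {0 -> 0, 2 -> 2}:
//     output dim 1 is the kept size-1 reduction axis; it maps to no input
//     dim, so lookups for it must not assume an entry exists.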
@@ -0,0 +1,217 @@
+//====-- onnx_to_krnl_common.hpp - ONNX dialects to Krnl lowering ---------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains common code shared by the functions performing the
+// lowering to the KRNL dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <map>
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Sequence.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "src/dialect/krnl/krnl_helper.hpp"
+#include "src/dialect/krnl/krnl_ops.hpp"
+#include "src/dialect/onnx/onnx_ops.hpp"
+#include "src/pass/passes.hpp"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Common functions used when lowering the ONNX frontend dialect to KRNL.
+//===----------------------------------------------------------------------===//
+
+/// Check if all dimensions are known at compile time.
+bool hasAllConstantDimensions(MemRefType type);
+
+/// Get the corresponding MemRefType of a given TensorType/MemRefType.
+MemRefType convertToMemRefType(Type type);
+
+/// Insert an allocation and deallocation for the given MemRefType.
+Value insertAllocAndDealloc(MemRefType type, Location loc,
+                            PatternRewriter &rewriter,
+                            bool insertDealloc,
+                            ArrayRef<Value> operands = {});
+
+// Determine if current function returns the result value of the
+// current op being lowered. If it does then dealloc should not be
+// inserted.
+bool checkInsertDealloc(Operation *currentOp);
+
+// Create a mapping from result type's dimensions to input type's dimensions,
+// given that the result type is the result of a reduction op over the input
+// type.
+std::map<int64_t, int64_t>
+getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);
+
+// Add bounds associated with the op operand to the KRNL iteration pack.
+// Dynamic dimensions are supported.
+void addDimensionToPack(ConversionPatternRewriter &rewriter,
+                        Location loc, KrnlIterateOperandPack &pack,
+                        Value operand, int index);
+
+// Function that defines the KRNL dialect loops and their respective
+// optimized version.
+KrnlOptimizeLoopsOp
+emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
+                   std::vector<Value> &loops,
+                   std::vector<Value> &optimizedLoops, int64_t numLoops);
+
+// Function that emits the loops and their optimized version.
+// The function returns a reference to the inner optimization block.
+Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
+                   std::vector<Value> &loops,
+                   std::vector<Value> &optimizedLoops,
+                   int64_t numLoops);
+
+// Function which emits a basic set of loops and optimized loops
+// for a given operation argument. A reference to the loop optimization
+// block is returned in the last argument of the function.
+void emitKrnlLoopsAndIterationForOperand(
+    ConversionPatternRewriter &rewriter, Location loc, Value operand,
+    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
+    KrnlIterateOp &iterateOp);
+
+unsigned getMemRefEltSizeInBytes(MemRefType memRefType);
+
+// Get run-time dimension information for unknown dimensions used for
+// broadcasting.
+std::map<int, std::map<int, Value>>
+getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
+                      MemRefType memRefType, ArrayRef<Value> operands);
+
+// Extract induction variables that are used for broadcasting values of a
+// given operand.
+std::vector<Value>
+getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
+                          ArrayRef<Value> loopIVs, Value operand,
+                          std::map<int, Value> broadcastedDims);
+
+//===----------------------------------------------------------------------===//
+// This is to get a scalar operation of a given type for a specific operation.
+//===----------------------------------------------------------------------===//
+template <typename Op>
+struct ScalarOp {
+  using FOp = void;
+  using IOp = void;
+};
+
+template <typename FOp>
+using ScalarFOp = typename ScalarOp<FOp>::FOp;
+template <typename IOp>
+using ScalarIOp = typename ScalarOp<IOp>::IOp;
+
+// Get the identity element of an operation.
+// Return NULL if the function does not have an identity.
+template <typename DataType, typename Op>
+DataType getIdentityValue() {
+  return NULL;
+}
+
+//===----------------------------------------------------------------------===//
+// This is used in the innermost loop of a KrnlIterateOp to insert computation
+// composed of one or many scalar ops.
+// Use template specialization for each of the different ONNX operations.
+//===----------------------------------------------------------------------===//
+template <typename Op>
+Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
+                         ArrayRef<Value> operands,
+                         ConversionPatternRewriter &rewriter) {
+  auto loc = op->getLoc();
+  Type element_type = operands.front().getType();
+  if (element_type.isa<IntegerType>()) {
+    return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
+                                          mlir::None);
+  } else if (element_type.isa<FloatType>()) {
+    return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
+                                          mlir::None);
+  } else {
+    emitError(loc, "unsupported element type");
+    return nullptr;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion from Tensor type to the Standard dialect MemRef type.
+//===----------------------------------------------------------------------===//
+struct TensorTypeConverter : public TypeConverter {
+  using TypeConverter::TypeConverter;
+
+  TensorTypeConverter() {
+    addConversion(convertType);
+  }
+
+  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
+    if (auto type = convertToMemRefType(t)) {
+      results.push_back(type);
+      return success();
+    }
+
+    results.push_back(t);
+    return success();
+  }
+
+  /// Return true if the inputs and outputs of the given function type are
+  /// legal. [Taken from MLIR and adapted to only check the legality of the
+  /// inputs. Once unranked results can be handled gracefully this
+  /// override needs to be removed in favour of the original MLIR one.]
+  bool isSignatureLegal(FunctionType funcType) {
+    return llvm::all_of(funcType.getInputs(),
+                        [this](Type type) { return isLegal(type); });
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Functions to add lowering patterns for frontend operations.
+//===----------------------------------------------------------------------===//
+
+// `math` directory methods:
+
+void populateLoweringONNXElementwiseOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
+                                       MLIRContext *ctx);
+
+void populateLoweringONNXMatMulOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXReductionOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXSoftmaxOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+// `nn` directory methods:
+
+void populateLoweringONNXConvOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXNormalizationOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+// `tensor` directory methods:
+
+void populateLoweringONNXUnsqueezeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXTransposeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXReshapeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXIdentityOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
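The header fixes the calling convention the per-operator files rely on: convert the result type, decide whether a dealloc is needed, then allocate. A minimal sketch of that sequence inside a pattern's matchAndRewrite, assuming the PatternMatchResult-era MLIR API used elsewhere in this codebase (the concrete patterns live in the math, nn, and tensor files and are not reproduced in this diff):

// Hypothetical pattern body showing the usual helper sequence.
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  auto loc = op->getLoc();
  auto memRefType = convertToMemRefType(*op->result_type_begin());

  // Skip the dealloc when the function returns this value.
  bool insertDealloc = checkInsertDealloc(op);

  Value alloc;
  if (hasAllConstantDimensions(memRefType))
    alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
  else
    // Dynamic dims are sized from the operands (the max over them, for
    // broadcasting ops).
    alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                  {operands[0]});

  // ... emit krnl.define_loops / krnl.iterate computing into `alloc` ...

  rewriter.replaceOp(op, alloc);
  return matchSuccess();
}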
@@ -1,4 +1,4 @@
-//===----- identity.inc - Lowering Identity Op ----------------------------===//
+//===----- identity.cpp - Lowering Identity Op ----------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXIdentityOpLowering : public ConversionPattern {
   ONNXIdentityOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
-//===----- reshape.inc - Lowering Reshape Op ------------------------------===//
+//===----- reshape.cpp - Lowering Reshape Op ------------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXReshapeOpLowering : public ConversionPattern {
   ONNXReshapeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
-//===----- transpose.inc - Lowering Transpose Op --------------------------===//
+//===----- transpose.cpp - Lowering Transpose Op --------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXTransposeOpLowering : public ConversionPattern {
   ONNXTransposeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
-//===----- unsqueeze.inc - Lowering Unsqueeze Op --------------------------===//
+//===----- unsqueeze.cpp - Lowering Unsqueeze Op --------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXUnsqueezeOpLowering : public ConversionPattern {
   ONNXUnsqueezeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}