Merge branch 'master' into shapeinference-pad
commit 907104d7e8
@@ -38,7 +38,7 @@ jobs:
      - run:
          name: Run End-To-End Tests
          command: |
            sudo pip install -q onnx
            sudo pip install -q -e ./ONNF/third_party/onnx
            cd ONNF/build
            cmake --build . --target run-onnx-backend-test
      - run:
@@ -1558,33 +1558,6 @@ ONNX Gather operation

1. `output`: memref of any type values or tensor of any type values

### onnx.GemmNoBias (ONNXGemmNoBiasOp)
ONNX general matrix multiply operation without bias.

#### Description:

The "onnx.Gemm" generic matrix multiplication without bias.

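(Assuming the usual ONNX Gemm semantics, this computes roughly `Y = alpha * A' * B'`, where `A'` and `B'` are `A` and `B` optionally transposed according to `transA` and `transB`; the `beta * C` bias term of the full Gemm is omitted.)
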
#### Operands:

1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values

#### Attributes:

| Attribute | MLIR Type | Description |
| :-------: | :-------: | ----------- |
| `alpha` | `FloatAttr` | 32-bit float attribute attribute |
| `beta` | `FloatAttr` | 32-bit float attribute attribute |
| `transA` | `IntegerAttr` | 64-bit integer attribute attribute |
| `transB` | `IntegerAttr` | 64-bit integer attribute attribute |

#### Results:

1. `o_Y`: memref of any type values or tensor of any type values

### onnx.Gemm (ONNXGemmOp)
ONNX Gemm operation

@@ -62,7 +62,21 @@ target_include_directories(onnf_shape_inference
target_link_libraries(onnf_shape_inference ${MLIRLibs})
add_dependencies(onnf_shape_inference gen_krnl_ops)

add_library(onnf_lower_frontend conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
add_library(onnf_lower_frontend
        conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
        conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
        conversion/onnx_to_krnl/math/elementwise.cpp
        conversion/onnx_to_krnl/math/gemm.cpp
        conversion/onnx_to_krnl/math/matmul.cpp
        conversion/onnx_to_krnl/math/reduction.cpp
        conversion/onnx_to_krnl/math/softmax.cpp
        conversion/onnx_to_krnl/nn/conv.cpp
        conversion/onnx_to_krnl/nn/normalization.cpp
        conversion/onnx_to_krnl/tensor/identity.cpp
        conversion/onnx_to_krnl/tensor/reshape.cpp
        conversion/onnx_to_krnl/tensor/transpose.cpp
        conversion/onnx_to_krnl/tensor/unsqueeze.cpp
        conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
target_include_directories(onnf_lower_frontend
        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
        ${ONNF_SRC_ROOT})
@@ -189,8 +189,9 @@ private:
      }
    }

    mlir::Type elementType =
        convertONNXTypeToMLIRType(input.type().tensor_type().elem_type());
    auto elementOnnxType =
        (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
    mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
    llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
    arg_types.emplace_back(
        mlir::RankedTensorType::get(tensor_dims, elementType));
@@ -8,404 +8,11 @@
// Krnl IR and standard operations.
//
//===----------------------------------------------------------------------===//
#include <map>

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"

#include "src/dialect/krnl/krnl_helper.hpp"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

//===----------------------------------------------------------------------===//
// FrontendToAffine RewritePatterns
//===----------------------------------------------------------------------===//

/// Check if all dimensions are known at compile time.
static bool hasAllConstantDimensions(MemRefType type) {
  auto memRefShape = type.getShape();
  for (int i = 0; i < memRefShape.size(); ++i)
    if (memRefShape[i] < 0)
      return false;
  return true;
}

/// Get the corresponding MemRefType of a given TensorType/MemRefType.
static MemRefType convertToMemRefType(Type type) {
  MemRefType memRefType;
  auto tensorType = type.dyn_cast<TensorType>();
  if (tensorType) {
    assert(tensorType.hasRank() && "expected only ranked shapes");
    memRefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  } else {
    memRefType = type.dyn_cast<MemRefType>();
  }
  return memRefType;
}

/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
                                   PatternRewriter &rewriter,
                                   bool insertDealloc,
                                   ArrayRef<Value> operands = {}) {
  // Put together alloc operands for any dynamic dimensions of the memref.
  AllocOp alloc;
  if (!operands.empty()) {
    auto memRefShape = type.getShape();
    auto rank = memRefShape.size();

    std::map<int, Value> fromOperands;
    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
      int memRefDimIdx = rank - 1 - reversedIdx;
      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
        Value maxDim = nullptr;
        for (int i = 0; i < operands.size(); i++) {
          auto operandShape =
              operands[i].getType().cast<MemRefType>().getShape();
          int operandDimIdx = operandShape.size() - 1 - reversedIdx;

          if (operandDimIdx < 0)
            continue;

          // In case of operations with broadcasting, the dimension of the
          // alloc result is the maximum size along each dimension of the
          // operands.
          auto operandDim =
              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
          if (maxDim) {
            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
                                                        operandDim, maxDim);
            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
                                               maxDim);
          } else {
            maxDim = operandDim;
          }
        }
        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
      }
    }

    SmallVector<Value, 4> allocOperands;
    for (int i = 0; i < rank; ++i)
      if (memRefShape[i] < 0)
        allocOperands.push_back(fromOperands[i]);
    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
  } else {
    alloc = rewriter.create<AllocOp>(loc, type);
  }

  // Make sure to allocate at the beginning of the block if
  // all dimensions are known.
  auto *parentBlock = alloc.getOperation()->getBlock();
  if (hasAllConstantDimensions(type))
    alloc.getOperation()->moveBefore(&parentBlock->front());

  if (insertDealloc) {
    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
    dealloc.getOperation()->moveBefore(&parentBlock->back());
  }

  return alloc;
}

// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
static bool checkInsertDealloc(Operation *currentOp) {
  auto parentBlock = currentOp->getBlock();

  bool insertDealloc = true;
  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
    assert(currentOp->getNumResults() < 2 &&
           "No more than one result supported (for now).");
    // If there is at least one result to investigate.
    if (currentOp->getNumResults() > 0) {
      auto result = currentOp->getResult(0);
      for (const auto &operand : op.getOperands())
        if (operand == result)
          insertDealloc = false;
    }
  });

  return insertDealloc;
}

// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
  std::map<int64_t, int64_t> OutInDimMap;
  int64_t rank = inputTy.getRank();

  // Mark reduction axes.
  std::vector<bool> isReductionAxis;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
      isReductionAxis.push_back(true);
    else
      isReductionAxis.push_back(false);
  }

  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
    // If it is a reduction axis, there is no relationship among dimensions.
    if (isReductionAxis[inIndex]) {
      if (keepdims)
        outIndex++;
    } else {
      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
      outIndex++;
    }
  }

  return OutInDimMap;
}

// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
static void addDimensionToPack(ConversionPatternRewriter &rewriter,
                               Location loc, KrnlIterateOperandPack &pack,
                               Value operand, int index) {
  auto shape = operand.getType().cast<MemRefType>().getShape();
  if (shape[index] < 0) {
    pack.pushConstantBound(0);
    pack.pushOperandBound(
        rewriter.create<DimOp>(loc, operand, index).getResult());
  } else {
    pack.pushConstantBound(0);
    pack.pushConstantBound(shape[index]);
  }
}

// Function that defines the KRNL dialect loops and their respective
// optimized version.
static KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
                   std::vector<Value> &loops,
                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
  // Define loops.
  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
  loops.reserve(numLoops);
  for (auto result : loopsOp.getResults())
    loops.push_back(result);

  // Define optimized version of the loops.
  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
  optimizedLoops.reserve(numLoops);
  for (auto result : optimizedLoopsOp.getResults())
    optimizedLoops.push_back(result);

  return optimizedLoopsOp;
}

// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
static Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
                          std::vector<Value> &loops,
                          std::vector<Value> &optimizedLoops,
                          int64_t numLoops) {
  KrnlOptimizeLoopsOp optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
  return &optimizedLoopsOp.region().front();
}

// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
static void emitKrnlLoopsAndIterationForOperand(
    ConversionPatternRewriter &rewriter, Location loc, Value operand,
    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
    KrnlIterateOp &iterateOp) {
  // Operand shape.
  auto shape = operand.getType().cast<MemRefType>().getShape();

  // Number of loops.
  int64_t rank = shape.size();

  // Define loops and optimized loops.
  std::vector<Value> optimizedLoops;
  optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);

  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
  // Iterate over the loop nest.
  for (int i = 0; i < rank; ++i)
    addDimensionToPack(rewriter, loc, pack, operand, i);

  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
}

unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
                      MemRefType memRefType, ArrayRef<Value> operands) {
  auto memRefShape = memRefType.getShape();
  int64_t rank = memRefShape.size();
  // For unknown dimensions, we need to get dimension values at runtime in
  // order to do broadcasting.
  std::map<int, std::map<int, Value>> DimInfo;
  // For each result dimension, compute the number of sharing operands.
  // Sharing operands are operands sharing the same index (counting from the
  // rightmost to the leftmost) for a given dimension.
  std::map<int, int> sharedDimCount;
  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    int dimIdx = rank - 1 - reversedIdx;
    sharedDimCount[dimIdx] = 0;
    for (int i = 0; i < operands.size(); ++i) {
      auto shape = operands[i].getType().cast<MemRefType>().getShape();
      if (reversedIdx <= shape.size() - 1)
        sharedDimCount[dimIdx]++;
    }
  }
  // An unknown dimension can have a value of 1 or N (N > 1).
  // If its value is 1, it is broadcasted dimension.
  // Otherwise, non-broadcasted dimension.
  // We only care about unknown dimensions whose number of sharing operands is
  // more than one, since they are potentially broadcasted dimensions.
  for (int i = 0; i < operands.size(); ++i) {
    std::map<int, Value> broadcastedDims;
    auto shape = operands[i].getType().cast<MemRefType>().getShape();
    int size = shape.size();
    for (int j = 0; j < shape.size(); ++j) {
      if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
        auto isBroadcasted =
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
      }
    }
    DimInfo.insert(std::make_pair(i, broadcastedDims));
  }
  return DimInfo;
}

// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
                          ArrayRef<Value> loopIVs, Value operand,
                          std::map<int, Value> broadcastedDims) {
  // `operand` must have a ranked type. This should have been checked by the
  // shape inference pass.
  auto operandShape = operand.getType().cast<MemRefType>().getShape();
  auto rank = operandShape.size();
  auto loopCount = loopIVs.size();

  std::vector<Value> newLoopIVs;
  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    auto dimIdx = rank - 1 - reversedIdx;
    auto loopIdx = loopCount - 1 - reversedIdx;
    if (operandShape[dimIdx] == 1) {
      // Broadcasted dimension
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      newLoopIVs.insert(newLoopIVs.begin(), zero);
    } else if ((operandShape[dimIdx] == -1) &&
               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
      // Unknown dimension, it can have a value of 1 or N (N > 1).
      // If its value is 1, it is broadcasted dimension.
      // Otherwise, non-broadcasted dimension.
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
                                           loopIVs[loopIdx]);
      newLoopIVs.insert(newLoopIVs.begin(), idx);
    } else {
      // Non-broadcasted dimension
      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
    }
  }
  return newLoopIVs;
}

namespace {

// This is to get a scalar operation of a given type for a specific operation.
template <typename Op>
struct ScalarOp {
  using FOp = void;
  using IOp = void;
};

template <typename FOp>
using ScalarFOp = typename ScalarOp<FOp>::FOp;
template <typename IOp>
using ScalarIOp = typename ScalarOp<IOp>::IOp;

// Get the identity element of an operation.
// Return NULL if the function does not have identity.
template <typename DataType, typename Op>
DataType getIdentityValue() {
  return NULL;
}

//===----------------------------------------------------------------------===//
// This is used in the innermost loop of a KrnlIterateOp to insert computation
// composed of one or many scalar ops.
// Use template specialization for each of different ONNX operations.
//===----------------------------------------------------------------------===//
template <typename Op>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
                         ArrayRef<Value> operands,
                         ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Type element_type = operands.front().getType();
  if (element_type.isa<IntegerType>()) {
    return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
                                          mlir::None);
  } else if (element_type.isa<FloatType>()) {
    return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
                                          mlir::None);
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

// We divide the operator lowering into different categories.
// These categories are mostly similar to the operator categories in ONNX:
// https://github.com/onnx/onnx/tree/master/onnx/defs.
// Besides, it is better to put operators with the same computation pattern into
// the same category, e.g. element-wise operators will belong to the elementwise
// category.

// Math
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc"
// Tensor
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
// Neural network
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc"

//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===//
@@ -427,39 +34,6 @@ public:
  }
};

//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//

struct TensorTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;

  TensorTypeConverter() {
    addConversion(convertType);
  }

  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
    if (auto type = convertToMemRefType(t)) {
      results.push_back(type);
      return success();
    }

    results.push_back(t);
    return success();
  }

  /// Return true if the inputs and outputs of the given function type are
  /// legal. [Taken from MLIR and adapted to only check the legality of the
  /// inputs. Once unranked results can be handled gracefully this
  /// override needs to be removed in favour of the original MLIR one.]
  bool isSignatureLegal(FunctionType funcType) {
    return llvm::all_of(funcType.getInputs(),
                        [this](Type type) { return isLegal(type); });
  }
};

} // end anonymous namespace.

//===----------------------------------------------------------------------===//
// Frontend to Krnl Dialect lowering pass
//===----------------------------------------------------------------------===//
@@ -1,4 +1,4 @@
//===----- elementwise.inc - Elementwise Ops ------------------------------===//
//===----- elementwise.cpp - Elementwise Ops ------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

template <>
struct ScalarOp<ONNXAddOp> {
  using FOp = AddFOp;
@@ -1,4 +1,4 @@
//===----- gemm.inc - Lowering Gemm Op ------------------------------------===//
//===----- gemm.cpp - Lowering Gemm Op ------------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

template <typename GemmOp>
struct ONNXGemmOpLowering : public ConversionPattern {
  ONNXGemmOpLowering(MLIRContext *ctx)
@@ -17,9 +21,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    // The first predicate is unnecessary when we remove ONXGemmNoBiasOp.
    bool hasBias = (operands.size() == 3) &&
                   (!op->getOperand(2).getType().isa<NoneType>());
    bool hasBias = !op->getOperand(2).getType().isa<NoneType>();

    Value A, B, C;
    A = operands[0];
@@ -215,5 +217,4 @@ struct ONNXGemmOpLowering : public ConversionPattern {
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
                                       MLIRContext *ctx) {
  patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
  patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
}
@@ -1,4 +1,4 @@
//===----- matmul.inc - Lowering Matmul Op --------------------------------===//
//===----- matmul.cpp - Lowering Matmul Op --------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

struct ONNXMatMulOpLowering : public ConversionPattern {
  ONNXMatMulOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
//===----- reduction.inc - Lowering Reduction Ops -------------------------===//
//===----- reduction.cpp - Lowering Reduction Ops -------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

// Identity values
template <>
float getIdentityValue<float, ONNXReduceMaxOp>(){
@@ -1,4 +1,4 @@
//===----- softmax.inc - Softmax Op ---------------------------------------===//
//===----- softmax.cpp - Softmax Op ---------------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

struct ONNXSoftmaxOpLowering : public ConversionPattern {
  ONNXSoftmaxOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
//===----- conv.inc - Lowering Convolution Op -----------------------------===//
//===----- conv.cpp - Lowering Convolution Op -----------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

struct ONNXConvNoBiasOpLowering : public ConversionPattern {
  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
//===----- normalization.inc - Lowering Normalization Ops -----------------===//
//===----- normalization.cpp - Lowering Normalization Ops -----------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
  ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
      : ConversionPattern(
@@ -0,0 +1,324 @@
//====-- onnx_to_krnl_common.cpp - ONNX dialects to Krnl lowering ---------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the KRNL dialect.
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

/// Check if all dimensions are known at compile time.
bool hasAllConstantDimensions(MemRefType type) {
  auto memRefShape = type.getShape();
  for (int i = 0; i < memRefShape.size(); ++i)
    if (memRefShape[i] < 0)
      return false;
  return true;
}

/// Get the corresponding MemRefType of a given TensorType/MemRefType.
MemRefType convertToMemRefType(Type type) {
  MemRefType memRefType;
  auto tensorType = type.dyn_cast<TensorType>();
  if (tensorType) {
    assert(tensorType.hasRank() && "expected only ranked shapes");
    memRefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  } else {
    memRefType = type.dyn_cast<MemRefType>();
  }
  return memRefType;
}

/// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc,
                            PatternRewriter &rewriter,
                            bool insertDealloc,
                            ArrayRef<Value> operands) {
  // Put together alloc operands for any dynamic dimensions of the memref.
  AllocOp alloc;
  if (!operands.empty()) {
    auto memRefShape = type.getShape();
    auto rank = memRefShape.size();

    std::map<int, Value> fromOperands;
    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
      int memRefDimIdx = rank - 1 - reversedIdx;
      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
        Value maxDim = nullptr;
        for (int i = 0; i < operands.size(); i++) {
          auto operandShape =
              operands[i].getType().cast<MemRefType>().getShape();
          int operandDimIdx = operandShape.size() - 1 - reversedIdx;

          if (operandDimIdx < 0)
            continue;

          // In case of operations with broadcasting, the dimension of the
          // alloc result is the maximum size along each dimension of the
          // operands.
          auto operandDim =
              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
          if (maxDim) {
            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
                                                        operandDim, maxDim);
            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
                                               maxDim);
          } else {
            maxDim = operandDim;
          }
        }
        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
      }
    }

    SmallVector<Value, 4> allocOperands;
    for (int i = 0; i < rank; ++i)
      if (memRefShape[i] < 0)
        allocOperands.push_back(fromOperands[i]);
    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
  } else {
    alloc = rewriter.create<AllocOp>(loc, type);
  }

  // Make sure to allocate at the beginning of the block if
  // all dimensions are known.
  auto *parentBlock = alloc.getOperation()->getBlock();
  if (hasAllConstantDimensions(type))
    alloc.getOperation()->moveBefore(&parentBlock->front());

  if (insertDealloc) {
    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
    dealloc.getOperation()->moveBefore(&parentBlock->back());
  }

  return alloc;
}

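// Illustrative example (not part of the original source): for a result type
// memref<?x10xf32> produced by a broadcasted binary op over operands of types
// memref<?x10xf32> and memref<1x10xf32>, the unknown dimension 0 is computed
// at run time as the maximum of the operands' sizes along that dimension
// (via DimOp, CmpIOp sgt and SelectOp) and is passed to the AllocOp.
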
// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
bool checkInsertDealloc(Operation *currentOp) {
  auto parentBlock = currentOp->getBlock();

  bool insertDealloc = true;
  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
    assert(currentOp->getNumResults() < 2 &&
           "No more than one result supported (for now).");
    // If there is at least one result to investigate.
    if (currentOp->getNumResults() > 0) {
      auto result = currentOp->getResult(0);
      for (const auto &operand : op.getOperands())
        if (operand == result)
          insertDealloc = false;
    }
  });

  return insertDealloc;
}

// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
  std::map<int64_t, int64_t> OutInDimMap;
  int64_t rank = inputTy.getRank();

  // Mark reduction axes.
  std::vector<bool> isReductionAxis;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
      isReductionAxis.push_back(true);
    else
      isReductionAxis.push_back(false);
  }

  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
    // If it is a reduction axis, there is no relationship among dimensions.
    if (isReductionAxis[inIndex]) {
      if (keepdims)
        outIndex++;
    } else {
      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
      outIndex++;
    }
  }

  return OutInDimMap;
}

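// Illustrative example (not part of the original source): for a rank-3 input
// reduced over axes = {1}, the mapping is {0 -> 0, 1 -> 2} when keepdims is
// false, and {0 -> 0, 2 -> 2} when keepdims is true (the reduced output
// dimension keeps its slot but has no corresponding input dimension).
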
// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter,
                        Location loc, KrnlIterateOperandPack &pack,
                        Value operand, int index) {
  auto shape = operand.getType().cast<MemRefType>().getShape();
  if (shape[index] < 0) {
    pack.pushConstantBound(0);
    pack.pushOperandBound(
        rewriter.create<DimOp>(loc, operand, index).getResult());
  } else {
    pack.pushConstantBound(0);
    pack.pushConstantBound(shape[index]);
  }
}

// Function that defines the KRNL dialect loops and their respective
// optimized version.
KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
                   std::vector<Value> &loops,
                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
  // Define loops.
  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
  loops.reserve(numLoops);
  for (auto result : loopsOp.getResults())
    loops.push_back(result);

  // Define optimized version of the loops.
  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
  optimizedLoops.reserve(numLoops);
  for (auto result : optimizedLoopsOp.getResults())
    optimizedLoops.push_back(result);

  return optimizedLoopsOp;
}

// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
                   std::vector<Value> &loops,
                   std::vector<Value> &optimizedLoops,
                   int64_t numLoops) {
  KrnlOptimizeLoopsOp optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
  return &optimizedLoopsOp.region().front();
}

// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand(
    ConversionPatternRewriter &rewriter, Location loc, Value operand,
    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
    KrnlIterateOp &iterateOp) {
  // Operand shape.
  auto shape = operand.getType().cast<MemRefType>().getShape();

  // Number of loops.
  int64_t rank = shape.size();

  // Define loops and optimized loops.
  std::vector<Value> optimizedLoops;
  optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);

  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
  // Iterate over the loop nest.
  for (int i = 0; i < rank; ++i)
    addDimensionToPack(rewriter, loc, pack, operand, i);

  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
}

unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
                      MemRefType memRefType, ArrayRef<Value> operands) {
  auto memRefShape = memRefType.getShape();
  int64_t rank = memRefShape.size();
  // For unknown dimensions, we need to get dimension values at runtime in
  // order to do broadcasting.
  std::map<int, std::map<int, Value>> DimInfo;
  // For each result dimension, compute the number of sharing operands.
  // Sharing operands are operands sharing the same index (counting from the
  // rightmost to the leftmost) for a given dimension.
  std::map<int, int> sharedDimCount;
  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    int dimIdx = rank - 1 - reversedIdx;
    sharedDimCount[dimIdx] = 0;
    for (int i = 0; i < operands.size(); ++i) {
      auto shape = operands[i].getType().cast<MemRefType>().getShape();
      if (reversedIdx <= shape.size() - 1)
        sharedDimCount[dimIdx]++;
    }
  }
  // An unknown dimension can have a value of 1 or N (N > 1).
  // If its value is 1, it is broadcasted dimension.
  // Otherwise, non-broadcasted dimension.
  // We only care about unknown dimensions whose number of sharing operands is
  // more than one, since they are potentially broadcasted dimensions.
  for (int i = 0; i < operands.size(); ++i) {
    std::map<int, Value> broadcastedDims;
    auto shape = operands[i].getType().cast<MemRefType>().getShape();
    int size = shape.size();
    for (int j = 0; j < shape.size(); ++j) {
      if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
        auto isBroadcasted =
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
      }
    }
    DimInfo.insert(std::make_pair(i, broadcastedDims));
  }
  return DimInfo;
}

// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
                          ArrayRef<Value> loopIVs, Value operand,
                          std::map<int, Value> broadcastedDims) {
  // `operand` must have a ranked type. This should have been checked by the
  // shape inference pass.
  auto operandShape = operand.getType().cast<MemRefType>().getShape();
  auto rank = operandShape.size();
  auto loopCount = loopIVs.size();

  std::vector<Value> newLoopIVs;
  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    auto dimIdx = rank - 1 - reversedIdx;
    auto loopIdx = loopCount - 1 - reversedIdx;
    if (operandShape[dimIdx] == 1) {
      // Broadcasted dimension
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      newLoopIVs.insert(newLoopIVs.begin(), zero);
    } else if ((operandShape[dimIdx] == -1) &&
               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
      // Unknown dimension, it can have a value of 1 or N (N > 1).
      // If its value is 1, it is broadcasted dimension.
      // Otherwise, non-broadcasted dimension.
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
                                           loopIVs[loopIdx]);
      newLoopIVs.insert(newLoopIVs.begin(), idx);
    } else {
      // Non-broadcasted dimension
      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
    }
  }
  return newLoopIVs;
}
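
// Illustrative example (not part of the original source): with loopIVs =
// (i, j, k) over a result of shape <2x3x4> and an operand of shape <3x1>,
// the operand is accessed with IVs (j, 0): the leading result dimension is
// dropped because the operand has lower rank, and the size-1 operand
// dimension is pinned to index 0.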
@@ -0,0 +1,217 @@
//====-- onnx_to_krnl_common.hpp - ONNX dialects to Krnl lowering ---------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the KRNL dialect.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <map>

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "mlir/IR/PatternMatch.h"

#include "src/dialect/krnl/krnl_helper.hpp"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Common functions used when lowering the ONNX frontend dialect to KRNL.
//===----------------------------------------------------------------------===//

/// Check if all dimensions are known at compile time.
bool hasAllConstantDimensions(MemRefType type);

/// Get the corresponding MemRefType of a given TensorType/MemRefType.
MemRefType convertToMemRefType(Type type);

/// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc,
                            PatternRewriter &rewriter,
                            bool insertDealloc,
                            ArrayRef<Value> operands = {});

// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
bool checkInsertDealloc(Operation *currentOp);

// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);

// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter,
                        Location loc, KrnlIterateOperandPack &pack,
                        Value operand, int index);

// Function that defines the KRNL dialect loops and their respective
// optimized version.
KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
                   std::vector<Value> &loops,
                   std::vector<Value> &optimizedLoops, int64_t numLoops);

// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
                   std::vector<Value> &loops,
                   std::vector<Value> &optimizedLoops,
                   int64_t numLoops);

// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand(
    ConversionPatternRewriter &rewriter, Location loc, Value operand,
    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
    KrnlIterateOp &iterateOp);

unsigned getMemRefEltSizeInBytes(MemRefType memRefType);

// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
                      MemRefType memRefType, ArrayRef<Value> operands);

// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
                          ArrayRef<Value> loopIVs, Value operand,
                          std::map<int, Value> broadcastedDims);

//===----------------------------------------------------------------------===//
// This is to get a scalar operation of a given type for a specific operation.
//===----------------------------------------------------------------------===//
template <typename Op>
struct ScalarOp {
  using FOp = void;
  using IOp = void;
};

template <typename FOp>
using ScalarFOp = typename ScalarOp<FOp>::FOp;
template <typename IOp>
using ScalarIOp = typename ScalarOp<IOp>::IOp;

// Get the identity element of an operation.
// Return NULL if the function does not have identity.
template <typename DataType, typename Op>
DataType getIdentityValue() {
  return NULL;
}

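// Sketch of the intended use (assumed; the concrete specializations live in
// the per-op lowering files, e.g. elementwise.cpp and reduction.cpp elsewhere
// in this change): an op ties itself to its scalar standard-dialect
// counterparts and, where meaningful, an identity element, roughly:
//   template <> struct ScalarOp<ONNXAddOp> { using FOp = AddFOp; /* IOp analogously */ };
//   template <> float getIdentityValue<float, ONNXReduceMaxOp>() { ... }
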
//===----------------------------------------------------------------------===//
// This is used in the innermost loop of a KrnlIterateOp to insert computation
// composed of one or many scalar ops.
// Use template specialization for each of different ONNX operations.
//===----------------------------------------------------------------------===//
template <typename Op>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
                         ArrayRef<Value> operands,
                         ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Type element_type = operands.front().getType();
  if (element_type.isa<IntegerType>()) {
    return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
                                          mlir::None);
  } else if (element_type.isa<FloatType>()) {
    return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
                                          mlir::None);
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

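// Sketch of the intended use (assumed, based on the comment above): inside
// the innermost KrnlIterateOp body, an elementwise lowering can emit its
// scalar computation with something like
//   Value res = mapToLowerScalarOp<ONNXAddOp>(op, {elementType}, {lhs, rhs}, rewriter);
// where lhs and rhs are the scalar values loaded for the current iteration;
// the helper picks the float or integer scalar op based on the element type.
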
//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//

struct TensorTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;

  TensorTypeConverter() {
    addConversion(convertType);
  }

  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
    if (auto type = convertToMemRefType(t)) {
      results.push_back(type);
      return success();
    }

    results.push_back(t);
    return success();
  }

  /// Return true if the inputs and outputs of the given function type are
  /// legal. [Taken from MLIR and adapted to only check the legality of the
  /// inputs. Once unranked results can be handled gracefully this
  /// override needs to be removed in favour of the original MLIR one.]
  bool isSignatureLegal(FunctionType funcType) {
    return llvm::all_of(funcType.getInputs(),
                        [this](Type type) { return isLegal(type); });
  }
};

//===----------------------------------------------------------------------===//
// Functions to add lowering patterns for frontend operations.
//===----------------------------------------------------------------------===//

// `math` directory methods:

void populateLoweringONNXElementwiseOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
                                       MLIRContext *ctx);

void populateLoweringONNXMatMulOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXReductionOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXSoftmaxOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

// `nn` directory methods:

void populateLoweringONNXConvOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXNormalizationOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

// `tensor` directory methods:

void populateLoweringONNXUnsqueezeOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXTransposeOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXReshapeOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXIdentityOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);
@@ -1,4 +1,4 @@
//===----- identity.inc - Lowering Identity Op ----------------------------===//
//===----- identity.cpp - Lowering Identity Op ----------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

struct ONNXIdentityOpLowering : public ConversionPattern {
  ONNXIdentityOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
//===----- reshape.inc - Lowering Reshape Op ------------------------------===//
//===----- reshape.cpp - Lowering Reshape Op ------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

struct ONNXReshapeOpLowering : public ConversionPattern {
  ONNXReshapeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
//===----- transpose.inc - Lowering Transpose Op --------------------------===//
//===----- transpose.cpp - Lowering Transpose Op --------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

struct ONNXTransposeOpLowering : public ConversionPattern {
  ONNXTransposeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
//===----- unsqueeze.inc - Lowering Unsqueeze Op --------------------------===//
//===----- unsqueeze.cpp - Lowering Unsqueeze Op --------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

struct ONNXUnsqueezeOpLowering : public ConversionPattern {
  ONNXUnsqueezeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
@ -131,7 +131,7 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
|
|||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||
}
|
||||
|
||||
void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
|
||||
void KrnlIterateOperandPack::pushOperandBound(Value operand) {
|
||||
if (boundMaps.size() % 2 == 0)
|
||||
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
|
||||
AffineMap map = builder.getSymbolIdentityMap();
|
||||
|
@ -145,7 +145,7 @@ BuildKrnlLoop::BuildKrnlLoop(
|
|||
pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
|
||||
createdIterateOp(false) {
|
||||
if (originalLoopNum <= 0)
|
||||
emitError(loc, "expected positive number of original loops");
|
||||
emitError(loc, "Expected positive number of original loops.");
|
||||
}
|
||||
|
||||
BuildKrnlLoop::BuildKrnlLoop(
@ -154,25 +154,24 @@ BuildKrnlLoop::BuildKrnlLoop(
          memRefOperand.getType().cast<MemRefType>().getShape().size()) {}

BuildKrnlLoop::~BuildKrnlLoop() {
  if (!createdDefineOp)
    emitError(loc, "expected to create define op");
  if (!createdIterateOp)
    emitError(loc, "expected to create iteration op");
  if (pack)
    free(pack);
}

void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
  // insert define loop op
  // Insert define loop operation.
  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
  originalLoops.reserve(originalLoopNum);
  for (auto result : loopsOp.getResults())
    originalLoops.push_back(result);
  // inserte optimize loop op.
  createdDefineOp = true;

  // Insert optimize loop operation.
  auto optimizedLoopsOp =
      rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
  optLoops.reserve(originalLoopNum);
  // Emit empty optimizations

  // Emit empty optimizations if flag is set.
  if (withEmptyOptimization) {
    for (auto result : optimizedLoopsOp.getResults())
      optLoops.push_back(result);

@ -182,12 +181,12 @@ void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
    rewriter.restoreInsertionPoint(ip);
  }
  createdOptimizeOp = true;

  // prepare data structure to push bounds
  pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
  createdOptimizeOp = true;
}

// push bounds (lower and upper) and return index for loop info
int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
  pack->pushConstantBound(lowerBound);
  pack->pushConstantBound(upperBound);

@ -203,17 +202,20 @@ int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
    int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
  pack->pushConstantBound(lowerBound);
  // process upperBound as a dimension of mem ref, possibly non-constant

  // Process upperBound as a dimension of the MemRef. Non-constant dimensions
  // are supported.
  auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
  if (shape[upperBoundMemRefIndex] < 0) {
    if (upperBoundMustBeConstant)
      emitError(loc, "bound expected to be constant");
      emitError(loc, "Bound expected to be constant.");
    pack->pushOperandBound(
        rewriter
            .create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
            .getResult());
  } else
    pack->pushConstantBound(shape[upperBoundMemRefIndex]);

  return pushCount++;
}

@ -223,19 +225,20 @@ int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) {
  return pushCount++;
}

// create iter
void BuildKrnlLoop::createIterateOp() {
  // Loop definition operation is mandatory.
  if (!createdDefineOp)
    emitError(loc, "must create define op before iterate op");
  // Tight now, optimize (possibly empty) is mandatory. This may change
    emitError(loc, "Must create define op before iterate op.");

  // Loop optimization operation is mandatory (for now).
  if (!createdOptimizeOp)
    emitError(loc, "must create optimize op before iterate op");
  // have to have defined all bounds
  if (pushCount != originalLoopNum) {
    printf(" push count %d, original loop %d\n", pushCount, originalLoopNum);
    emitError(loc, "must push bounds for all original loops");
  }
  // create iterate op
    emitError(loc, "Must create optimize op before iterate op.");

  // Check if all bounds have been defined.
  if (pushCount != originalLoopNum)
    emitError(loc, "Must push bounds for all original loops.");

  // Emit iteration operation.
  auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
  iterBlock = &iterateOp.bodyRegion().front();
  createdIterateOp = true;

@ -243,19 +246,27 @@ void BuildKrnlLoop::createIterateOp() {

void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
    Value memRefOperand, bool withEmptyOptimization) {
  // Rank of the MemRef operand. We will emit a loop for each dimension.
  int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
  if (originalLoopNum != loopNum)
    emitError(loc, "mismatch in loop numbers from constructor and define");
    emitError(loc, "Mismatch in loop numbers from constructor and define.");

  // Emit the definition and the optimization operations for the loop nest.
  createDefineAndOptimizeOp(withEmptyOptimization);

  // Push a lower-upper bound pair for each dimension of the MemRef operand.
  // The lower bound in this case is always zero.
  for (int i = 0; i < originalLoopNum; ++i)
    pushBounds(0, memRefOperand, i);

  // Emit the iteration operation over the current loop nest.
  createIterateOp();
}

// get induction variable to be use within iter
BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
  // Check if loop iteration variable is within bounds.
  if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
    emitError(loc, "original loop index is out of bound");
    emitError(loc, "Original loop index is out of bounds.");
  return iterBlock->getArguments()[originalLoopIndex];
}
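As an illustration of the convenience path implemented above, a minimal usage sketch (not part of this change); `rewriter`, `loc`, and the ranked MemRef value `alloc` are assumed to be provided by the surrounding conversion pattern:

  BuildKrnlLoop loops(rewriter, loc, alloc);
  // One call emits krnl.define_loops, an identity krnl.optimize_loops, and
  // krnl.iterate, with one loop per dimension of 'alloc' from 0 to that dim.
  loops.createDefineOptimizeAndIterateOp(alloc);
  // Populate the loop body using the induction variables.
  rewriter.setInsertionPointToStart(loops.getIterateBlock());
  Value i = loops.getInductionVar(0); // first dimension of 'alloc'
  Value j = loops.getInductionVar(1); // second dimension (assumes rank 2)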
@ -106,19 +106,21 @@ private:
//
// The sequence is as follow:
//
// 1) Create a object giving the rewriter, location, and number of loop in the
// original (non optimized) loop.
// 1) Create an object giving the rewriter, location, and number of loop in
// the original (non optimized) loop.
//
// 2) Create define & optimize ops (currently paired). Optimizations can then
// be added to the inner block of the optimize operation. Make sure to set the
// insertion point to that block for optimizations to go in the right place.
// be added to the inner block of the optimize operation. Make sure to set
// the insertion point to that block for optimizations to go in the right
// place.
//
// 3) Push the bounds for each of the original loops. Bounds are pushed in
// pairs (lower & upper bounds). THere are a few methods to do it depending on
// the type of the bounds. When pushing bounds, the method returns a number
// that represent the index associated with that iteration (induction variable
// and bounds). That index can be used later to extract the induction variable
// for reference in computation and/or index calculations of mem refs.
// pairs (lower & upper bounds). There are a few methods to do it depending
// on the type of the bounds. When pushing bounds, the method returns a
// number that represent the index associated with that iteration (induction
// variable and bounds). That index can be used later to extract the
// induction variable for reference in computation and/or index calculations
// of mem refs.
//
// 4) Once all the bounds are pushed, create the iterate operation. Once this
// is done, we can add operations within the iterate blocks by setting the

@ -127,67 +129,90 @@ private:

class BuildKrnlLoop {
public:
  // Create a build kernel loop for the given location and loop number.
  // Create kernel loop builder for a loop nest of depth loopNum.
  BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
  // Do the same, but where the loop number corresponds to the dimensionality of
  // the mem ref operand.

  // Create kernel loop builder for a loop nest of depth equal to the
  // dimensionality of the operand. An operand of MemRef type is requied.
  BuildKrnlLoop(
      ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
  ~BuildKrnlLoop();

  // Create define and optimize loop with loopNum original loops. If
  // withEmptyOptimization, the optimization is simply the identity function (no
  // optimizations).
  // withEmptyOptimization is true, the optimization is simply the identity
  // function (no optimizations).
  void createDefineAndOptimizeOp(bool withEmptyOptimization = true);

  // Push bounds (lower and upper) for each of the loops, in order. It returns
  // the index associated with the loop iteration. This index is in the range
  // from zero to original loop number -1, and is monotonally increasing from
  // call to call. This index is later used in the getInductionVar call.
  // Push bounds (lower and upper) for each of the loops (order matters).
  // The function returns the order number associated with the loop iteration.
  // This index is used by the getInductionVar call. Non-constant operands
  // must be of MemRef type.
  int pushBounds(int64_t lowerBound, int64_t upperBound);
  int pushBounds(int64_t lowerBound, Value upperBound);
  int pushBounds(Value lowerBound, Value upperBound);
  // same, where the lower bound is an integer, and the uppoer bound is given by
  // the size of the mem ref operand along the upperBoundMemRefIndex dimension.
  int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
      int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);

  // Create an iterate op.
  // Create the KrnlIterateOp assiciated with this loop nest. The loops
  // iteration will be created if the definition and the optimization
  // operations associated with this loop nest have been emitted already.
  void createIterateOp();
  // Create an define, optimize and iterate op, with the same loop nummber as
  // the rank of the memRefOperand. The lower bound of each loops is zero, and
  // the upper bound of each loops is the dimension given by the mem refs

  // Create the loop nest definition, optimization and iteration operations
  // for a given operand of MemRef type. The loop nest has a depth equal to the
  // rank of the MemRef operand. The lower bound of each loop is zero. The
  // upper bound of each loop is given by the corresponding dimension of the
  // MemRef operand.
  void createDefineOptimizeAndIterateOp(
      Value memRefOperand, bool withEmptyOptimization = true);

  // Get the (original loop) induction variable associated with the given index.
  // Use the index returned when pushing the bounds.
  // Get the (original loop) induction variable associated with the given
  // index. Use the index returned when pushing the bounds.
  BlockArgument &getInductionVar(int originalLoopIndex);

  // Get blocks. This allow us to set the insertion point to the inner block of
  // the optimize and the iterate Operation
  // Get a reference to the code region of the optimization operation.
  // This allows us to set the insertion point to the inner block of the
  // loop nest optimization operation.
  Block *getOptimizationBlock() { return optBlock; }

  // Get a reference to the code region of the iteration operation.
  // This allows us to set the insertion point to the inner block of the
  // loop nest iteration operation.
  Block *getIterateBlock() { return iterBlock; }

  // get original or optimized loops
  // Get original loop nest.
  std::vector<Value> &getOriginalLoops() { return originalLoops; }

  // Get optimized loop nest.
  std::vector<Value> &getOptimizedLoops() { return optLoops; }

private:
  // inputs
  // Required for emitting operations.
  ConversionPatternRewriter &rewriter;
  Location loc;
  int originalLoopNum;
  // track loops and bounds

  // List of original, un-optimized loops.
  std::vector<Value> originalLoops;

  // List of optimized loops.
  std::vector<Value> optLoops;

  // List of lower-upper bound pairs needed by the KrnlIterateOp.
  KrnlIterateOperandPack *pack;

  // Number of lower-upper bound pairs pushed.
  int pushCount;

  // Flags that keep track of emitted operations.
  bool createdDefineOp;
  bool createdOptimizeOp;
  bool createdIterateOp;
  // insertion points (opt block, iterate)

  // Saved insertion point in the code region of the KrnlOptimizeLoopsOp.
  Block *optBlock;

  // Saved insertion point in the code region of the KrnlIterateOp.
  Block *iterBlock;
};
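For reference, a minimal sketch of the four-step sequence documented above (illustrative only, not part of this diff); the rank-2 MemRef value `memRef`, plus `rewriter` and `loc`, are assumed to come from the enclosing conversion pattern, and the constant bound 64 is arbitrary:

  // Step 1: one builder per loop nest, here of depth 2.
  BuildKrnlLoop krnlLoop(rewriter, loc, /*loopNum=*/2);
  // Step 2: emit the define and (identity) optimize operations.
  krnlLoop.createDefineAndOptimizeOp();
  // Step 3: push one lower/upper bound pair per original loop and keep the
  // returned indices; the first upper bound comes from dimension 0 of memRef.
  int iIndex = krnlLoop.pushBounds(0, memRef, 0);
  int jIndex = krnlLoop.pushBounds(0, 64);
  // Step 4: emit the iterate operation, then build its body.
  krnlLoop.createIterateOp();
  rewriter.setInsertionPointToStart(krnlLoop.getIterateBlock());
  Value i = krnlLoop.getInductionVar(iIndex);
  Value j = krnlLoop.getInductionVar(jIndex);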
@ -90,25 +90,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
// or outputs. This decision affects only ONNX operations with optional
// arguments not ONNX operations with variadic operands.

def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  let summary = "ONNX general matrix multiply operation without bias.";
  let description = [{

  The "onnx.Gemm" generic matrix multiplication without bias.

  }];

  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
      AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
      DefaultValuedAttr<F32Attr, "1.0">:$alpha,
      DefaultValuedAttr<F32Attr, "1.0">:$beta,
      DefaultValuedAttr<I64Attr, "0">:$transA,
      DefaultValuedAttr<I64Attr, "0">:$transB);

  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}

def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  let hasCanonicalizer = 1;
@ -565,32 +565,6 @@ void ONNXGemmOp::inferShapes() {
  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}

// GemmNoBias

void ONNXGemmNoBiasOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();

  int64_t M, N, K_A, K_B;
  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];

  if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
    emitError("Tensor shapes mismatched.");
  }

  SmallVector<int64_t, 2> dims;
  dims.emplace_back(M);
  dims.emplace_back(N);
  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}

/// BatchNormalizationTestMode
void ONNXBatchNormalizationTestModeOp::inferShapes() {
  // Cannot infer shape if no shape exists.
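As a worked example of the removed shape inference: with the operands used by the test_gemm_no_bias function removed further below (A : tensor<5x10xf32> with transA = 1, B : tensor<5x10xf32> with transB = 0), M = 10, K_A = 5, N = 10, and K_B = 5. The K dimensions agree, so the result is inferred as tensor<10x10xf32>, which the lowered test checks as memref<10x10xf32>.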
@ -118,7 +118,6 @@ public:
            op->getName().getStringRef() != "onnx.Identity" &&
            op->getName().getStringRef() != "onnx.MatMul" &&
            op->getName().getStringRef() != "onnx.Gemm" &&
            op->getName().getStringRef() != "onnx.GemmNoBias" &&
            op->getName().getStringRef() != "onnx.Reshape" &&
            op->getName().getStringRef() != "onnx.Transpose" &&
            op->getName().getStringRef() != "onnx.ReduceMax" &&
@ -806,35 +806,6 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
  // CHECK: }
}

func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
  %0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_gemm_no_bias
  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
  // CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
  // CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
  // CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
  // CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
  // CHECK: }
  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
  // CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
  // CHECK: }
  // CHECK: return [[RES]] : memref<10x10xf32>
  // CHECK: }
}

func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
  %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()
@ -1 +1 @@
Subproject commit 1439eab5542c625bb3da49860f0cd68c3eafdc18
Subproject commit 553df22c67bee5f0fe6599cff60f1afc6748c635