[NFC] Categorize ONNX ops lowering (#80)
* Create two categories: elementwise and tensor
* typos
* Create directories for categories
* Edit comments
* Extract a function that creates a KrnlIterateOp
* Add comments
* Extract some common parts
* Revise softmax
* Add reduction.inc
* Move lower-frontend to lib/conversion
* Move directory to directory
* Change file/directory names
* Comment format
* Add matmul.inc
This commit is contained in:
parent 3c505ae31d
commit b9f2f25b56
CMakeLists.txt
@@ -57,7 +57,7 @@ target_include_directories(onnf_shape_inference
 target_link_libraries(onnf_shape_inference ${MLIRLibs})
 add_dependencies(onnf_shape_inference gen_krnl_ops)
 
-add_library(onnf_lower_frontend pass/lower_frontend_to_krnl.cpp)
+add_library(onnf_lower_frontend conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
 target_include_directories(onnf_lower_frontend
         PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
         ${ONNF_SRC_ROOT})
src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp
@@ -0,0 +1,529 @@
//====- convert_onnx_to_krnl.cpp - ONNX dialects to Krnl lowering ---------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements the lowering of frontend operations to a combination of
// Krnl IR and standard operations.
//
//===----------------------------------------------------------------------===//

#include <map>

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"

#include "src/dialect/krnl/krnl_helper.hpp"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"

using namespace mlir;

//===----------------------------------------------------------------------===//
// FrontendToKrnl RewritePatterns
//===----------------------------------------------------------------------===//

/// Check if all dimensions are known at compile time.
static bool hasAllConstantDimensions(MemRefType type) {
  auto memRefShape = type.getShape();
  for (int i = 0; i < memRefShape.size(); ++i)
    if (memRefShape[i] < 0)
      return false;
  return true;
}

/// Convert the given TensorType into the corresponding MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) {
  assert(type.hasRank() && "expected only ranked shapes");
  return MemRefType::get(type.getShape(), type.getElementType());
}

/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
                                   PatternRewriter &rewriter,
                                   bool insertDealloc,
                                   ArrayRef<Value> operands = {}) {
  // Put together alloc operands for any dynamic dimensions of the memref.
  AllocOp alloc;
  if (!operands.empty()) {
    auto memRefShape = type.getShape();
    auto rank = memRefShape.size();

    std::map<int, Value> fromOperands;
    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
      int memRefDimIdx = rank - 1 - reversedIdx;
      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
        Value maxDim = nullptr;
        for (int i = 0; i < operands.size(); i++) {
          auto operandShape =
              operands[i].getType().cast<MemRefType>().getShape();
          int operandDimIdx = operandShape.size() - 1 - reversedIdx;

          if (operandDimIdx < 0)
            continue;

          // In case of operations with broadcasting, the dimension of the
          // alloc result is the maximum size along each dimension of the
          // operands.
          auto operandDim =
              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
          if (maxDim) {
            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
                                                        operandDim, maxDim);
            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
                                               maxDim);
          } else {
            maxDim = operandDim;
          }
        }
        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
      }
    }

    SmallVector<Value, 4> allocOperands;
    for (int i = 0; i < rank; ++i)
      if (memRefShape[i] < 0)
        allocOperands.push_back(fromOperands[i]);
    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
  } else {
    alloc = rewriter.create<AllocOp>(loc, type);
  }

  // Make sure to allocate at the beginning of the block if
  // all dimensions are known.
  auto *parentBlock = alloc.getOperation()->getBlock();
  if (hasAllConstantDimensions(type))
    alloc.getOperation()->moveBefore(&parentBlock->front());

  if (insertDealloc) {
    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
    dealloc.getOperation()->moveBefore(&parentBlock->back());
  }

  return alloc;
}

// Determine if the current function returns the result value of the
// current op being lowered. If it does, then dealloc should not be
// inserted.
static bool checkInsertDealloc(Operation *currentOp) {
  auto parentBlock = currentOp->getBlock();

  bool insertDealloc = true;
  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
    assert(currentOp->getNumResults() < 2 &&
           "No more than one result supported (for now).");
    // If there is at least one result to investigate.
    if (currentOp->getNumResults() > 0) {
      auto result = currentOp->getResult(0);
      for (const auto &operand : op.getOperands())
        if (operand == result)
          insertDealloc = false;
    }
  });

  return insertDealloc;
}

// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
  std::map<int64_t, int64_t> OutInDimMap;
  int64_t rank = inputTy.getRank();

  // Mark reduction axes.
  std::vector<bool> isReductionAxis;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
      isReductionAxis.push_back(true);
    else
      isReductionAxis.push_back(false);
  }

  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
    // If it is a reduction axis, there is no relationship among dimensions.
    if (isReductionAxis[inIndex]) {
      if (keepdims)
        outIndex++;
    } else {
      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
      outIndex++;
    }
  }

  return OutInDimMap;
}

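// Worked example (hypothetical shapes, not taken from the code above):
// reducing a rank-4 input over axes = {1, 3} with keepdims = false yields a
// rank-2 result and the mapping {0 -> 0, 1 -> 2}; with keepdims = true the
// result keeps rank 4 and the mapping is {0 -> 0, 2 -> 2}.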

// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
static void addDimensionToPack(ConversionPatternRewriter &rewriter,
                               Location loc, KrnlIterateOperandPack &pack,
                               Value operand, int index) {
  auto shape = operand.getType().cast<MemRefType>().getShape();
  if (shape[index] < 0) {
    pack.pushConstantBound(0);
    pack.pushOperandBound(
        rewriter.create<DimOp>(loc, operand, index).getResult());
  } else {
    pack.pushConstantBound(0);
    pack.pushConstantBound(shape[index]);
  }
}

// Function that defines the KRNL dialect loops and their respective
// optimized version.
static KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
                   std::vector<Value> &loops,
                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
  // Define loops.
  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
  loops.reserve(numLoops);
  for (auto result : loopsOp.getResults())
    loops.push_back(result);

  // Define optimized version of the loops.
  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
  optimizedLoops.reserve(numLoops);
  for (auto result : optimizedLoopsOp.getResults())
    optimizedLoops.push_back(result);

  return optimizedLoopsOp;
}

// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
static Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
                          std::vector<Value> &loops,
                          std::vector<Value> &optimizedLoops,
                          int64_t numLoops) {
  KrnlOptimizeLoopsOp optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
  return &optimizedLoopsOp.region().front();
}

// Function which emits a basic set of loops and optimized loops
// for a given operation argument. The Krnl optimize-loops op and the Krnl
// iterate op are returned through the last two arguments.
static void emitKrnlLoopsAndIterationForOperand(
    ConversionPatternRewriter &rewriter, Location loc, Value operand,
    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
    KrnlIterateOp &iterateOp) {
  // Operand shape.
  auto shape = operand.getType().cast<MemRefType>().getShape();

  // Number of loops.
  int64_t rank = shape.size();

  // Define loops and optimized loops.
  std::vector<Value> optimizedLoops;
  optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);

  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
  // Iterate over the loop nest.
  for (int i = 0; i < rank; ++i)
    addDimensionToPack(rewriter, loc, pack, operand, i);

  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
}

unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

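// For example, an f32 element occupies 4 bytes, an i1 element occupies 1 byte
// (divideCeil rounds the single bit up), and a vector<4xf32> element occupies
// 16 bytes.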

// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
                      MemRefType memRefType, ArrayRef<Value> operands) {
  auto memRefShape = memRefType.getShape();
  int64_t rank = memRefShape.size();
  // For unknown dimensions, we need to get dimension values at runtime in
  // order to do broadcasting.
  std::map<int, std::map<int, Value>> DimInfo;
  // For each result dimension, compute the number of sharing operands.
  // Sharing operands are operands sharing the same index (counting from the
  // rightmost to the leftmost) for a given dimension.
  std::map<int, int> sharedDimCount;
  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    int dimIdx = rank - 1 - reversedIdx;
    sharedDimCount[dimIdx] = 0;
    for (int i = 0; i < operands.size(); ++i) {
      auto shape = operands[i].getType().cast<MemRefType>().getShape();
      if (reversedIdx <= shape.size() - 1)
        sharedDimCount[dimIdx]++;
    }
  }
  // An unknown dimension can have a value of 1 or N (N > 1).
  // If its value is 1, it is a broadcasted dimension; otherwise it is not.
  // We only care about unknown dimensions whose number of sharing operands is
  // more than one, since they are potentially broadcasted dimensions.
  for (int i = 0; i < operands.size(); ++i) {
    std::map<int, Value> broadcastedDims;
    auto shape = operands[i].getType().cast<MemRefType>().getShape();
    int size = shape.size();
    for (int j = 0; j < shape.size(); ++j) {
      if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
        auto isBroadcasted =
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
      }
    }
    DimInfo.insert(std::make_pair(i, broadcastedDims));
  }
  return DimInfo;
}

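// Example (hypothetical operand shapes): with operands of shapes <3x?xf32>
// and <?xf32>, the rightmost dimension is shared by both operands, so a
// runtime `dim == 1` check is recorded for each operand whose size in that
// dimension is unknown; the leftmost dimension belongs to a single operand
// and needs no check.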

// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
                          ArrayRef<Value> loopIVs, Value operand,
                          std::map<int, Value> broadcastedDims) {
  // `operand` must have a ranked type. This should have been checked by the
  // shape inference pass.
  auto operandShape = operand.getType().cast<MemRefType>().getShape();
  auto rank = operandShape.size();
  auto loopCount = loopIVs.size();

  std::vector<Value> newLoopIVs;
  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    auto dimIdx = rank - 1 - reversedIdx;
    auto loopIdx = loopCount - 1 - reversedIdx;
    if (operandShape[dimIdx] == 1) {
      // Broadcasted dimension.
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      newLoopIVs.insert(newLoopIVs.begin(), zero);
    } else if ((operandShape[dimIdx] == -1) &&
               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
      // Unknown dimension: it can have a value of 1 or N (N > 1).
      // If its value is 1, it is a broadcasted dimension; otherwise it is not.
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
                                           loopIVs[loopIdx]);
      newLoopIVs.insert(newLoopIVs.begin(), idx);
    } else {
      // Non-broadcasted dimension.
      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
    }
  }
  return newLoopIVs;
}

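// Example (hypothetical shapes): if the output loop IVs are (%i, %j) and the
// operand has shape <1x?xf32>, the returned IVs are
// (0, select(dim == 1, 0, %j)): the static size-1 dimension always reads
// index 0, and the unknown dimension picks 0 or %j depending on the runtime
// broadcast check computed by getBroadcastedDimInfo.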

namespace {

// This is used to get the scalar operation of a given element type for a
// specific ONNX operation.
template <typename Op>
struct ScalarOp {
  using FOp = void;
  using IOp = void;
};

template <typename FOp>
using ScalarFOp = typename ScalarOp<FOp>::FOp;
template <typename IOp>
using ScalarIOp = typename ScalarOp<IOp>::IOp;

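// For example, given the specialization ScalarOp<ONNXAddOp> (defined in
// elementwise.inc), ScalarFOp<ONNXAddOp> resolves to AddFOp and
// ScalarIOp<ONNXAddOp> resolves to AddIOp, so the generic mapToLowerScalarOp
// below emits a single AddFOp or AddIOp depending on the element type.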

// Get the identity element of an operation.
// Return NULL if the operation does not have an identity element.
template <typename DataType, typename Op>
DataType getIdentityValue() {
  return NULL;
}

//===----------------------------------------------------------------------===//
// This is used in the innermost loop of a KrnlIterateOp to insert computation
// composed of one or many scalar ops.
// Use template specialization for each of the different ONNX operations.
//===----------------------------------------------------------------------===//
template <typename Op>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
                         ArrayRef<Value> operands,
                         ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Type element_type = operands.front().getType();
  if (element_type.isa<IntegerType>()) {
    return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
                                          mlir::None);
  } else if (element_type.isa<FloatType>()) {
    return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
                                          mlir::None);
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

// We divide the operator lowering into different categories.
// These categories are mostly similar to the operator categories in ONNX:
// https://github.com/onnx/onnx/tree/master/onnx/defs.
// In addition, operators with the same computation pattern are put into the
// same category, e.g. element-wise operators belong to the elementwise
// category.

// Math
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc"
// Tensor
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
// Neural network
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"

//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===//

class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
public:
  using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(ONNXEntryPointOp op,
                                     PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
        op,
        op.getAttrOfType<SymbolRefAttr>(
            ONNXEntryPointOp::getEntryPointFuncAttrName()),
        op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
        op.getAttrOfType<IntegerAttr>(
            ONNXEntryPointOp::getNumOutputsAttrName()));
    return matchSuccess();
  }
};

//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//

struct TensorTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;

  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
    if (auto tensor_type = t.dyn_cast<TensorType>()) {
      results.push_back(convertTensorToMemRef(tensor_type));
      return success();
    }

    results.push_back(t);
    return success();
  }

  /// Return true if the inputs and outputs of the given function type are
  /// legal. [Taken from MLIR and adapted to only check the legality of the
  /// inputs. Once unranked results can be handled gracefully this
  /// override needs to be removed in favour of the original MLIR one.]
  bool isSignatureLegal(FunctionType funcType) {
    return llvm::all_of(funcType.getInputs(),
                        [this](Type type) { return isLegal(type); });
  }
};

} // end anonymous namespace.

//===----------------------------------------------------------------------===//
// Frontend to Krnl Dialect lowering pass
//===----------------------------------------------------------------------===//

/// This is a partial lowering of the ONNX operations to Krnl loops.
namespace {
struct FrontendToKrnlLoweringPass
    : public ModulePass<FrontendToKrnlLoweringPass> {
  void runOnModule() final;
};
} // end anonymous namespace.

void FrontendToKrnlLoweringPass::runOnModule() {
  auto module = getModule();

  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering.
  target
      .addLegalDialect<KrnlOpsDialect, AffineOpsDialect, StandardOpsDialect>();

  // TODO: enable this once more ops are supported.
  // We also define the ONNX dialect as illegal so that the conversion will
  // fail if any of its operations are *not* converted.
  // target.addIllegalDialect<mlir::ONNXOpsDialect>();

  // TODO: add any other ops which are considered legal.
  // Some operations can be marked as being still legal.
  // Example: target.addLegalOp<mlir::OpName>();

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the frontend operations.
  OwningRewritePatternList patterns;

  // Convert TensorType to MemRef.
  TensorTypeConverter tensor_to_memref_converter;
  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
    // FuncOp is legal only if types have been converted to Std types.
    return tensor_to_memref_converter.isSignatureLegal(op.getType());
  });

  // Type conversion for function signatures.
  // Call MLIR FuncOp signature conversion when result type is
  // a ranked tensor.
  populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                      tensor_to_memref_converter);

  // Frontend operation lowering.
  // Math
  populateLoweringONNXElementwiseOpPattern(patterns, &getContext());
  populateLoweringONNXGemmOpPattern(patterns, &getContext());
  populateLoweringONNXReductionOpPattern(patterns, &getContext());
  populateLoweringONNXSoftmaxOpPattern(patterns, &getContext());
  populateLoweringONNXMatMulOpPattern(patterns, &getContext());
  // Tensor
  populateLoweringONNXReshapeOpPattern(patterns, &getContext());
  populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
  populateLoweringONNXTransposeOpPattern(patterns, &getContext());
  populateLoweringONNXIdentityOpPattern(patterns, &getContext());
  // Neural network
  populateLoweringONNXConvOpPattern(patterns, &getContext());
  // Entry point
  patterns.insert<ONNXEntryPointLowering>(&getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
  return std::make_unique<FrontendToKrnlLoweringPass>();
}

static PassRegistration<FrontendToKrnlLoweringPass>
    pass("lower-frontend", "Lower frontend ops to Krnl dialect.");

src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc
@@ -0,0 +1,646 @@
//===----- elementwise.inc - Elementwise Ops ------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers ONNX element-wise operators to the Krnl dialect.
//
//===----------------------------------------------------------------------===//

template <>
struct ScalarOp<ONNXAddOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXMulOp> {
  using FOp = MulFOp;
  using IOp = MulIOp;
};

template <>
struct ScalarOp<ONNXDivOp> {
  using FOp = DivFOp;
  using IOp = SignedDivIOp;
};

template <>
struct ScalarOp<ONNXSubOp> {
  using FOp = SubFOp;
  using IOp = SubIOp;
};

template <>
struct ScalarOp<ONNXAndOp> {
  using FOp = AndOp; // not used
  using IOp = AndOp;
};

template <>
struct ScalarOp<ONNXOrOp> {
  using FOp = OrOp; // not used
  using IOp = OrOp;
};

template <>
struct ScalarOp<ONNXXorOp> {
  using FOp = XOrOp; // not used
  using IOp = XOrOp;
};

template <>
struct ScalarOp<ONNXExpOp> {
  using FOp = ExpOp;
  using IOp = ExpOp; // not used
};

template <>
struct ScalarOp<ONNXSumOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXTanhOp> {
  using FOp = TanhOp;
  using IOp = TanhOp; // not used
};

template <>
struct ScalarOp<ONNXCosOp> {
  using FOp = CosOp;
  using IOp = CosOp; // not used
};

template <>
struct ScalarOp<ONNXLogOp> {
  using FOp = LogOp;
  using IOp = LogOp; // not used
};

template <>
struct ScalarOp<ONNXSqrtOp> {
  using FOp = KrnlSqrtOp;
  using IOp = KrnlSqrtOp; // not used
};

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
                                     ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) {
  // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
  //                         ConstantOp 2)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, rewriter.create<SubFOp>(loc, exp, negExp), two);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXCoshOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
                                     ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) {
  // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
  //                         ConstantOp 2)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, rewriter.create<AddFOp>(loc, exp, negExp), two);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
                                        ArrayRef<Type> result_types,
                                        ArrayRef<Value> operands,
                                        ConversionPatternRewriter &rewriter) {
  // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
  //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, one, rewriter.create<AddFOp>(loc, one, negExp));

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // %Y = AddFOp(MulFOp(alpha, %X), beta)
  // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
  //               %Y,
  //               Constant 0)
  // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
  //                                  %Z,
  //                                  Constant 1)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
  auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);

  auto add = rewriter.create<AddFOp>(
      loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
  auto maxPredicate =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, add, zero);
  auto max = rewriter.create<SelectOp>(loc, maxPredicate, add, zero);
  auto minPredicate =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, max, one);
  auto result = rewriter.create<SelectOp>(loc, minPredicate, max, one);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXEluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
                                    ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) {
  // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
  //                          %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(
      loc, lessThanZero,
      rewriter.create<MulFOp>(loc, alpha,
                              rewriter.create<SubFOp>(loc, exp, one)),
      operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
                                     ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) {
  // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                           ConstantOp 0,
  //                           %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXLeakyReluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
                                          ArrayRef<Type> result_types,
                                          ArrayRef<Value> operands,
                                          ConversionPatternRewriter &rewriter) {
  // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                                MulFOp(alpha, %X),
  //                                %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(
      loc, lessThanZero, rewriter.create<MulFOp>(loc, alpha, operand), operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSeluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
                                     ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) {
  // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
  //                           MulFOp(gamma, %X),
  //                           MulFOp(gamma,
  //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
  //                                         alpha)))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
  auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto greaterThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
  auto select = rewriter.create<SelectOp>(
      loc, greaterThanZero, operand,
      rewriter.create<SubFOp>(loc, rewriter.create<MulFOp>(loc, alpha, exp),
                              alpha));
  auto result = rewriter.create<MulFOp>(loc, gamma, select);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReciprocalOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReciprocalOp>(
    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto result = rewriter.create<DivFOp>(loc, one, operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftplusOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftplusOp>(
    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto add = rewriter.create<AddFOp>(loc, exp, one);
  auto result = rewriter.create<LogOp>(loc, add);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftsignOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftsignOp>(
    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSoftsignOp(%X) = DivFOp(%X, AddFOp(AbsFOp(%X), ConstantOp 1))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto abs = rewriter.create<AbsFOp>(loc, operand);
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto add = rewriter.create<AddFOp>(loc, abs, one);
  auto result = rewriter.create<DivFOp>(loc, operand, add);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSignOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
                                     ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Value operand = operands[0];
  Type element_type = operands.front().getType();
  // TODO: unsigned int should be supported separately?
  if (element_type.isa<IntegerType>()) {
    // %Y = SelectOp(CmpIOp(sgt, %X, ConstantOp 0),
    //               ConstantOp 1,
    //               ConstantOp -1)
    // ONNXSignOp(%X) = SelectOp(CmpIOp(eq, %X, ConstantOp 0),
    //                           ConstantOp 0,
    //                           %Y)
    auto zero = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
    auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
    auto minusOne =
        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(-1));
    auto plusPredicate =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
    auto plusSelect =
        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
    auto zeroPredicate =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, operand, zero);
    auto result =
        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
    return result;
  } else if (element_type.isa<FloatType>()) {
    // %Y = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
    //               ConstantOp 1,
    //               ConstantOp -1)
    // ONNXSignOp(%X) = SelectOp(CmpFOp(OEQ, %X, ConstantOp 0),
    //                           ConstantOp 0,
    //                           %Y)
    auto zero =
        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
    auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
    auto minusOne =
        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0f));
    auto plusPredicate =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
    auto plusSelect =
        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
    auto zeroPredicate =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, operand, zero);
    auto result =
        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
    return result;
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

//===----------------------------------------------------------------------===//
// Scalar binary ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
                                    ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) {
  // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
  //                              %X,
  //                              %Y)
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
  return result;
}

//===----------------------------------------------------------------------===//
// Scalar binary ops for lowering ONNXMinOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
                                    ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) {
  // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
  //                              %X,
  //                              %Y)
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
  return result;
}

// Element-wise unary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
  ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // An element-wise unary operation must have all operands and the result
    // of the same type. This should have been verified by the verifier.
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);

    // If the output has a dynamic dimension, pass the operands required for
    // each dynamic dimension to the AllocOp. The first operand of the
    // operation is used. The operands of the op need to match in terms of
    // dimensions with the result at this pre-optimization phase.
    // TODO: verify that dimensions match.
    // TODO: can the dimension of the result differ after optimizations?
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    {operands[0]});

    std::vector<Value> originalLoops;
    KrnlOptimizeLoopsOp optimizedLoopsOp;
    KrnlIterateOp iterateOp;
    emitKrnlLoopsAndIterationForOperand(
        rewriter, loc, operands[0], originalLoops, optimizedLoopsOp, iterateOp);
    Block &optimizationBlock = optimizedLoopsOp.region().front();
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 2. Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
    auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>(
        op, memRefType.getElementType(), {loadedVal}, rewriter);
    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

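// Rough sketch of the generated code (schematic, not exact Krnl syntax):
// lowering a unary op such as ONNXExpOp over a 2-D tensor produces a
// define-loops/optimize-loops pair (the optimization body simply returns the
// loops unchanged), a krnl iterate over the output index space, and an
// iterate body that loads x[%i, %j], applies the scalar op, and stores the
// value into the allocated result buffer.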

// Element-wise variadic ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
  ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // An element-wise variadic operation must have all operands and the result
    // of the same type. This should have been verified by the verifier.
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();
    auto numArgs = op->getNumOperands();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    // If the output has a dynamic dimension, we compute its dimension at
    // runtime by using dimensions from the operands.
    // In particular, we need to know from which operand a result dimension
    // comes.
    // TODO: can the dimension of the result differ after optimizations?
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    operands);

    // Get run-time dimension information for unknown dimensions used for
    // broadcasting.
    std::map<int, std::map<int, Value>> broadcastedDimInfo =
        getBroadcastedDimInfo(loc, rewriter, memRefType, operands);

    std::vector<Value> originalLoops;
    KrnlOptimizeLoopsOp optimizedLoopsOp;
    KrnlIterateOp iterateOp;
    emitKrnlLoopsAndIterationForOperand(
        rewriter, loc, alloc, originalLoops, optimizedLoopsOp, iterateOp);
    Block &optimizationBlock = optimizedLoopsOp.region().front();
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 2. Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    // Fold over operands for each of their scalar values.
    Value accumulated, next;
    auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
        loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
    accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
    for (unsigned i = 1; i < numArgs; i++) {
      auto nextLoopIVs = getLoopIVsForBroadcasting(
          loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]);
      next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs);
      accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>(
          op, memRefType.getElementType(), {accumulated, next}, rewriter);
    }
    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

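// For example (hypothetical operands): lowering an ONNXSumOp with three
// operands A, B and C emits, inside the iterate body, loads of A, B and C at
// their broadcast-adjusted indices and the folded computation ((a + b) + c),
// which is then stored into the allocated result.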

void populateLoweringONNXElementwiseOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSignOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx);
}
@ -0,0 +1,209 @@
|
||||||
|
//===----- gemm.inc - Lowering Gemm Op ------------------------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file lowers the ONNX Gemm Operator to Krnl dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
struct ONNXGemmOpLowering : public ConversionPattern {
|
||||||
|
ONNXGemmOpLowering(MLIRContext *ctx)
|
||||||
|
: ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
|
PatternMatchResult
|
||||||
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
|
Value A, B, C;
|
||||||
|
A = operands[0];
|
||||||
|
B = operands[1];
|
||||||
|
C = operands[2];
|
||||||
|
|
||||||
|
auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
|
||||||
|
llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
|
||||||
|
auto betaAttr = FloatAttr::get(tensorType.getElementType(),
|
||||||
|
llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
|
||||||
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
||||||
|
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
|
||||||
|
|
||||||
|
bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
|
||||||
|
bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
|
||||||
|
|
||||||
|
// Result type
|
||||||
|
auto memRefType = convertTensorToMemRef(tensorType);
|
||||||
|
|
||||||
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
|
Value alloc;
|
||||||
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
if (hasAllConstantDimensions(memRefType))
|
||||||
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
|
else {
|
||||||
|
auto memRefShape = memRefType.getShape();
|
||||||
|
SmallVector<Value, 2> allocOperands;
|
||||||
|
if (memRefShape[0] < 0) {
|
||||||
|
auto dim = rewriter.create<DimOp>(loc, A, (isTransA) ? 1 : 0);
|
||||||
|
allocOperands.emplace_back(dim);
|
||||||
|
}
|
||||||
|
if (memRefShape[1] < 0) {
|
||||||
|
auto dim = rewriter.create<DimOp>(loc, B, (isTransB) ? 0 : 1);
|
||||||
|
allocOperands.emplace_back(dim);
|
||||||
|
}
|
||||||
|
alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
|
||||||
|
if (insertDealloc) {
|
||||||
|
auto *parentBlock = alloc.getDefiningOp()->getBlock();
|
||||||
|
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
|
||||||
|
dealloc.getOperation()->moveBefore(&parentBlock->back());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number of loops
|
||||||
|
auto memRefShape = memRefType.getShape();
|
||||||
|
int64_t numLoops = 3;
|
||||||
|
|
||||||
|
// Define loops.
|
||||||
|
std::vector<Value> originalLoops;
|
||||||
|
std::vector<Value> optimizedLoops;
|
||||||
|
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
|
||||||
|
optimizedLoops, numLoops);
|
||||||
|
|
||||||
|
// We have two Krnl loops:
|
||||||
|
// - Outer loop iterates over the output matrix dimensions, and
|
||||||
|
// - Reduction loop iterates over the reduction dimension.
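    // Conceptually (matching the loads/stores emitted below): for every output
    // element (m, n),
    //   Y[m, n] = alpha * (sum over k of A'[m, k] * B'[k, n]) + beta * C[m, n],
    // where A' and B' are A and B optionally transposed via transA/transB and
    // C is unidirectionally broadcast to M x N.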

    // Outer loop
    std::vector<Value> outerLoops, optimizedOuterLoops;
    outerLoops.reserve(2);
    optimizedOuterLoops.reserve(2);
    for (int i = 0; i < 2; ++i) {
      outerLoops.push_back(originalLoops[i]);
      optimizedOuterLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack outerPack(rewriter, outerLoops,
        optimizedOuterLoops);
    // Induction variables for the outer loops
    for (int i = 0; i < 2; ++i)
      addDimensionToPack(rewriter, loc, outerPack, alloc, i);

    // Reduction loop
    std::vector<Value> reductionLoops, optimizedReductionLoops;
    reductionLoops.reserve(1);
    optimizedReductionLoops.reserve(1);
    reductionLoops.push_back(originalLoops[2]);
    optimizedReductionLoops.push_back(optimizedLoops[2]);
    KrnlIterateOperandPack reductionPack(rewriter, reductionLoops,
        optimizedReductionLoops);
    // Induction variable for the reduction dimension
    // Try to find and use a static value from A or B first.
    // If that fails, use a dynamic value.
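    // For example: with transA = 0 the reduction (K) dimension of A is its
    // second dimension, with transA = 1 it is the first; hence the K_A_Idx /
    // K_B_Idx selection below.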
    auto ATy = A.getType().cast<MemRefType>();
    auto BTy = B.getType().cast<MemRefType>();
    int64_t K_A_Idx = (isTransA) ? 0 : 1;
    int64_t K_B_Idx = (isTransB) ? 1 : 0;
    reductionPack.pushConstantBound(0);
    if (ATy.getShape()[K_A_Idx] != -1)
      reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
    else if (BTy.getShape()[K_B_Idx] != -1)
      reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
    else
      reductionPack.pushOperandBound(
          rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());

    // Get run-time dimension information for unknown dimensions used for
    // broadcasting.
    // GemmOp supports unidirectional broadcasting from C to A*B, so it
    // suffices to get broadcasting information for C only.
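    // For example, if C has static shape (1, N), dimension 0 is known to be
    // broadcast; only dynamic (-1) dimensions need the run-time check below,
    // which compares the dimension size against 1.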
    std::map<int, Value> broadcastedDimInfo;
    auto shape = C.getType().cast<MemRefType>().getShape();
    for (int i = 0; i < shape.size(); ++i) {
      if (shape[i] < 0) {
        auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
        auto isBroadcasted =
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
        broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
      }
    }

    auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);

    // Now insert instructions into the bodies of the just-generated loops:

    // No optimization
    rewriter.setInsertionPointToEnd(optimizationBlock);
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // Insert instructions inside the outer loop.
    Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&outerIterationBlock);

    // Induction variables
    SmallVector<Value, 4> loopMNIVs;
    for (auto arg : outerIterationBlock.getArguments()) {
      loopMNIVs.emplace_back(arg);
    }

    // Initialize the output of A*B
    auto zero = rewriter.create<ConstantOp>(
        loc, FloatAttr::get(memRefType.getElementType(), 0));
    rewriter.create<StoreOp>(loc, zero, alloc, loopMNIVs);

    // Compute A*B
    auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);

    // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
    auto loopCIVs = getLoopIVsForBroadcasting(
        loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
    auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
    auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
    auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
    auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
    auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
    rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);

    // Insert instructions to do matrix multiplication: A*B
    Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&matmulIterationBlock);

    // Induction variables
    SmallVector<Value, 4> loopKIVs, loopAIVs, loopBIVs;
    for (auto arg : matmulIterationBlock.getArguments())
      loopKIVs.emplace_back(arg);
    if (isTransA) {
      loopAIVs.emplace_back(loopKIVs[0]);
      loopAIVs.emplace_back(loopMNIVs[0]);
    } else {
      loopAIVs.emplace_back(loopMNIVs[0]);
      loopAIVs.emplace_back(loopKIVs[0]);
    }
    if (isTransB) {
      loopBIVs.emplace_back(loopMNIVs[1]);
      loopBIVs.emplace_back(loopKIVs[0]);
    } else {
      loopBIVs.emplace_back(loopKIVs[0]);
      loopBIVs.emplace_back(loopMNIVs[1]);
    }

    // Matmul computation
    auto loadedA = rewriter.create<LoadOp>(loc, A, loopAIVs);
    auto loadedB = rewriter.create<LoadOp>(loc, B, loopBIVs);
    auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
    auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
    auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
    rewriter.create<StoreOp>(loc, accumulated, alloc, loopMNIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

void populateLoweringONNXGemmOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXGemmOpLowering>(ctx);
}
@ -0,0 +1,345 @@
//===----- matmul.inc - Lowering Matmul Op --------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Matmul Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//

struct ONNXMatMulOpLowering : public ConversionPattern {
  ONNXMatMulOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    Value A = operands[0];
    Value B = operands[1];
    auto AShape = A.getType().cast<MemRefType>().getShape();
    auto BShape = B.getType().cast<MemRefType>().getShape();

    // There are three cases related to the shapes of the two arguments:
    // - Both arguments are N-D, N >= 2
    // - Either argument is 1-D, the other is N-D, N >= 2
    // - Both arguments are 1-D
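    // For example (following the shape rules spelled out below):
    //   (2 x 3 x 4 x 5) MATMUL (5 x 6) gives (2 x 3 x 4 x 6),
    //   (4) MATMUL (2 x 4 x 5) gives (2 x 5), and
    //   (4) MATMUL (4) gives a single-element (dot-product) result.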
    // Result type
    auto memRefType = convertTensorToMemRef(tensorType);
    auto elementType = memRefType.getElementType();
    auto memRefShape = memRefType.getShape();

    // A value zero
    Value zero;
    if (elementType.isa<IntegerType>()) {
      zero = rewriter.create<ConstantOp>(
          loc, IntegerAttr::get(memRefType.getElementType(), 0));
    } else if (elementType.isa<FloatType>()) {
      zero = rewriter.create<ConstantOp>(
          loc, FloatAttr::get(memRefType.getElementType(), 0));
    } else {
      emitError(loc, "unsupported element type");
    }

    // Insert an allocation and deallocation for the result of this operation.
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else {
      SmallVector<Value, 4> allocOperands;
      if (AShape.size() >= 2 && BShape.size() >= 2) {
        // Both arguments are N-D, N >= 2
        // (s1 x s2 x... x sK x M x K) MATMUL (K x N)
        // =>
        // (s1 x s2 x... x sK x M x N)
        for (int i = 0; i < memRefShape.size() - 2; ++i) {
          if (memRefShape[i] < 0) {
            if ((AShape.size() == 2) && (BShape.size() > 2))
              allocOperands.emplace_back(rewriter.create<DimOp>(loc, B, i));
            else if ((AShape.size() > 2) && (BShape.size() == 2))
              allocOperands.emplace_back(rewriter.create<DimOp>(loc, A, i));
          }
        }
        if (memRefShape[memRefShape.size() - 2] < 0) {
          auto dim = rewriter.create<DimOp>(loc, A, memRefShape.size() - 2);
          allocOperands.emplace_back(dim);
        }
        if (memRefShape[memRefShape.size() - 1] < 0) {
          auto dim = rewriter.create<DimOp>(loc, B, memRefShape.size() - 1);
          allocOperands.emplace_back(dim);
        }
      } else if (AShape.size() == 1 && BShape.size() >= 2) {
        // A is 1-D, B is N-D (N >= 2)
        // K MATMUL (s1 x s2 x... x sK x K x N)
        // =>
        // (s1 x s2 x... x sK x N)
        for (int i = 0; i < memRefShape.size() - 1; ++i) {
          if (memRefShape[i] < 0) {
            auto dim = rewriter.create<DimOp>(loc, B, i);
            allocOperands.emplace_back(dim);
          }
        }
        if (memRefShape[memRefShape.size() - 1] < 0) {
          auto dim = rewriter.create<DimOp>(loc, B, BShape.size() - 1);
          allocOperands.emplace_back(dim);
        }
      } else if (AShape.size() >= 2 && BShape.size() == 1) {
        // A is N-D (N >= 2), B is 1-D
        // (s1 x s2 x... x sK x M x K) MATMUL K
        // =>
        // (s1 x s2 x... x sK x M)
        for (int i = 0; i < memRefShape.size() - 1; ++i) {
          if (memRefShape[i] < 0) {
            auto dim = rewriter.create<DimOp>(loc, A, i);
            allocOperands.emplace_back(dim);
          }
        }
        if (memRefShape[memRefShape.size() - 1] < 0) {
          auto dim = rewriter.create<DimOp>(loc, A, AShape.size() - 2);
          allocOperands.emplace_back(dim);
        }
      } else if (AShape.size() == 1 && BShape.size() == 1) {
        // Both arguments are 1-D
        if (memRefShape[0] < 0) {
          auto dim = rewriter.create<DimOp>(loc, A, 0);
          allocOperands.emplace_back(dim);
        }
      } else {
        emitError(loc, "Invalid shapes");
      }

      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
    }

    if (AShape.size() >= 2 || BShape.size() >= 2) {
      // Cases 1 and 2:
      // - Both arguments are N-D, N >= 2
      // - Either argument is 1-D, the other is N-D, N >= 2

      // Define loops for batch dimensions.
      std::vector<Value> originalLoops;
      std::vector<Value> optimizedLoops;
      Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
          optimizedLoops, memRefShape.size());

      // Outer KrnlIterateOp
      SmallVector<Value, 4> loopBatchIVs;
      bool hasBatchLoop = false;
      if (AShape.size() > 2 || BShape.size() > 2) {
        SmallVector<int, 4> batchAxes;
        int matmulResultDims =
            ((AShape.size() == 1 || BShape.size() == 1)) ? 1 : 2;
        for (int i = 0; i < memRefShape.size() - matmulResultDims; ++i)
          batchAxes.emplace_back(i);

        std::vector<Value> outerLoops, optimizedOuterLoops;
        outerLoops.reserve(batchAxes.size());
        optimizedOuterLoops.reserve(batchAxes.size());
        for (int i = 0; i < batchAxes.size(); ++i) {
          outerLoops.push_back(originalLoops[i]);
          optimizedOuterLoops.push_back(optimizedLoops[i]);
        }
        KrnlIterateOperandPack outerPack(rewriter, outerLoops,
            optimizedOuterLoops);
        for (int i = 0; i < batchAxes.size(); ++i) {
          addDimensionToPack(rewriter, loc, outerPack, alloc, i);
        }
        auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);

        // No optimization
        rewriter.setInsertionPointToEnd(optimizationBlock);
        rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

        // Insert instructions into the outer KrnlIterateOp.
        Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
        rewriter.setInsertionPointToStart(&outerIterationBlock);

        // Induction variables: non-matrix-multiplication variables.
        for (auto arg : outerIterationBlock.getArguments()) {
          loopBatchIVs.emplace_back(arg);
        }

        hasBatchLoop = true;
      }

      // Now, we define loops for matrix multiplication.

      // Create a KrnlIterateOp for matrix multiplication.
      KrnlIterateOp matmulIterateOp;
      std::vector<Value> matmulLoops, optimizedMatmulLoops;
      if (AShape.size() >= 2 && BShape.size() >= 2) {
        // N-D x N-D (both >= 2): the matmul result contributes two dimensions.
        matmulLoops.reserve(2);
        optimizedMatmulLoops.reserve(2);
        for (int i = 2; i > 0; --i) {
          matmulLoops.emplace_back(originalLoops[memRefShape.size() - i]);
          optimizedMatmulLoops.emplace_back(
              optimizedLoops[memRefShape.size() - i]);
        }
        KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
            optimizedMatmulLoops);
        for (int i = 2; i > 0; --i) {
          addDimensionToPack(rewriter, loc, matmulPack, alloc,
              memRefShape.size() - i);
        }
        matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
      } else {
        // 1-D x N-D, or N-D x 1-D: the matmul result contributes one dimension.
        matmulLoops.reserve(1);
        optimizedMatmulLoops.reserve(1);
        matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
        optimizedMatmulLoops.emplace_back(
            optimizedLoops[memRefShape.size() - 1]);
        KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
            optimizedMatmulLoops);
        addDimensionToPack(rewriter, loc, matmulPack, alloc,
            memRefShape.size() - 1);
        matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
      }

      if (!hasBatchLoop) {
        // No optimization
        rewriter.setInsertionPointToEnd(optimizationBlock);
        rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
      }

      // Insert instructions into the matmul KrnlIterateOp.
      Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&matmulIterationBlock);

      // Induction variables: M, N
      SmallVector<Value, 4> loopMNIVs;
      for (auto arg : matmulIterationBlock.getArguments()) {
        loopMNIVs.emplace_back(arg);
      }
      // Induction variables for the final result.
      SmallVector<Value, 4> loopBatchMNIVs;
      for (auto arg : loopBatchIVs) {
        loopBatchMNIVs.emplace_back(arg);
      }
      for (auto arg : loopMNIVs) {
        loopBatchMNIVs.emplace_back(arg);
      }

      // Fill the output with value 0.
      rewriter.create<StoreOp>(loc, zero, alloc, loopBatchMNIVs);

      // Iterate along the reduction dimension.
      // Use a value from A.
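      // The reduction bound is taken from the last dimension of A, which is
      // the shared K dimension of the product.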
      std::vector<Value> reduceLoops;
      std::vector<Value> optimizedReduceLoops;
      Block *optimizationReduceBlock =
          defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
      KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
          optimizedReduceLoops);
      addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
      auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);

      // No optimization
      rewriter.setInsertionPointToEnd(optimizationReduceBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, reduceLoops);

      // Insert instructions into the reduction KrnlIterateOp.
      Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&reduceIterationBlock);

      // Induction variables
      SmallVector<Value, 4> loopKIVs, loopBatchMKIVs, loopBatchKNIVs;
      // K
      loopKIVs.emplace_back(reduceIterationBlock.getArguments()[0]);
      // MK
      if (AShape.size() > 2)
        for (auto arg : loopBatchIVs)
          loopBatchMKIVs.emplace_back(arg);
      if (AShape.size() >= 2)
        loopBatchMKIVs.emplace_back(loopMNIVs[0]);
      loopBatchMKIVs.emplace_back(loopKIVs[0]);
      // KN
      if (BShape.size() > 2)
        for (auto arg : loopBatchIVs)
          loopBatchKNIVs.emplace_back(arg);
      loopBatchKNIVs.emplace_back(loopKIVs[0]);
      if (BShape.size() >= 2) {
        if (AShape.size() >= 2)
          loopBatchKNIVs.emplace_back(loopMNIVs[1]);
        else
          loopBatchKNIVs.emplace_back(loopMNIVs[0]);
      }

      // Matmul computation
      auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs);
      auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs);
      auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopBatchMNIVs);
      if (elementType.isa<IntegerType>()) {
        auto AB = rewriter.create<MulIOp>(loc, loadedA, loadedB);
        auto accumulated = rewriter.create<AddIOp>(loc, loadedY, AB);
        rewriter.create<StoreOp>(loc, accumulated, alloc, loopBatchMNIVs);
      } else if (elementType.isa<FloatType>()) {
        auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
        auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
        rewriter.create<StoreOp>(loc, accumulated, alloc, loopBatchMNIVs);
      }
    } else if ((AShape.size() == 1) && (BShape.size() == 1)) {
      // Case 3:
      // - Both arguments are 1-D

      // Fill the output with value 0.
      Value zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
      rewriter.create<StoreOp>(loc, zero, alloc, zeroIndex);

      // Iterate along the reduction dimension.
      // Use a value from A.
      std::vector<Value> reduceLoops;
      std::vector<Value> optimizedReduceLoops;
      Block *optimizationReduceBlock =
          defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
      KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
          optimizedReduceLoops);
      addDimensionToPack(rewriter, loc, reducePack, A, 0);
      auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);

      // No optimization
      rewriter.setInsertionPointToEnd(optimizationReduceBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, reduceLoops);

      // Insert instructions into the reduction KrnlIterateOp.
      Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&reduceIterationBlock);

      // Induction variables
      SmallVector<Value, 4> loopKIVs;
      // K
      loopKIVs.emplace_back(reduceIterationBlock.getArgument(0));

      // Matmul computation
      auto loadedA = rewriter.create<LoadOp>(loc, A, loopKIVs);
      auto loadedB = rewriter.create<LoadOp>(loc, B, loopKIVs);
      auto loadedY = rewriter.create<LoadOp>(loc, alloc, zeroIndex);
      if (elementType.isa<IntegerType>()) {
        auto AB = rewriter.create<MulIOp>(loc, loadedA, loadedB);
        auto accumulated = rewriter.create<AddIOp>(loc, loadedY, AB);
        rewriter.create<StoreOp>(loc, accumulated, alloc, zeroIndex);
      } else if (elementType.isa<FloatType>()) {
        auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
        auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
        rewriter.create<StoreOp>(loc, accumulated, alloc, zeroIndex);
      }
    } else {
      // No scalar matrix multiplication.
      llvm_unreachable("Unsupported scalar matrix multiplication.");
    }

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

void populateLoweringONNXMatMulOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXMatMulOpLowering>(ctx);
}
@ -0,0 +1,307 @@
//===----- reduction.inc - Lowering Reduction Ops -------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Reduction Operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//

// Identity values
template <>
float getIdentityValue<float, ONNXReduceMaxOp>() {
  return (float)-std::numeric_limits<float>::infinity();
}

template <>
int getIdentityValue<int, ONNXReduceMaxOp>() {
  return std::numeric_limits<int>::min();
}

template <>
float getIdentityValue<float, ONNXReduceMinOp>() {
  return (float)std::numeric_limits<float>::infinity();
}

template <>
int getIdentityValue<int, ONNXReduceMinOp>() {
  return std::numeric_limits<int>::max();
}

template <>
float getIdentityValue<float, ONNXReduceProdOp>() {
  return (float)1.0;
}

template <>
int getIdentityValue<int, ONNXReduceProdOp>() {
  return 1;
}

template <>
float getIdentityValue<float, ONNXReduceSumOp>() {
  return (float)0;
}

template <>
int getIdentityValue<int, ONNXReduceSumOp>() {
  return 0;
}

// Scalar ops
template <>
struct ScalarOp<ONNXReduceProdOp> {
  using FOp = MulFOp;
  using IOp = MulIOp;
};

template <>
struct ScalarOp<ONNXReduceSumOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};
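// Taken together: ReduceSum starts from identity 0 and accumulates with
// Add{F,I}Op, ReduceProd starts from 1 and accumulates with Mul{F,I}Op, while
// ReduceMax/ReduceMin start from the infinities (or integer min/max) and use
// the compare-and-select lowerings defined below.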

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReduceMaxOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op,
                                          ArrayRef<Type> result_types,
                                          ArrayRef<Value> operands,
                                          ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  Type element_type = lhs.getType();
  if (element_type.isa<IntegerType>()) {
    auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
    return result;
  } else if (element_type.isa<FloatType>()) {
    auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
    return result;
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReduceMinOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReduceMinOp>(Operation *op,
                                          ArrayRef<Type> result_types,
                                          ArrayRef<Value> operands,
                                          ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  Type element_type = lhs.getType();
  if (element_type.isa<IntegerType>()) {
    auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
    return result;
  } else if (element_type.isa<FloatType>()) {
    auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
    return result;
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

template <typename ONNXReductionOp>
struct ONNXReductionOpLowering : public ConversionPattern {
  ONNXReductionOpLowering(MLIRContext *ctx)
      : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    /*
     * Condition: reduction function must be associative and commutative.
     *
     * Example 1 (here, reduction function is `+`):
     * Induction variables: (i0, i1, i2)
     * axes = [0, 2]
     * keepdims = true
     * krnl.iterate() with (i0, i1, i2) {
     *   Y(0, i1, 0) += X(i0, i1, i2)
     * }
     *
     * Example 2 (here, reduction function is `+`):
     * Induction variables: (i0, i1, i2)
     * axes = [0, 2]
     * keepdims = false
     * krnl.iterate() with (i0, i1, i2) {
     *   Y(i1) += X(i0, i1, i2)
     * }
     *
     */
    auto loc = op->getLoc();
    auto memRefInType = operands[0].getType().cast<MemRefType>();
    auto memRefInShape = memRefInType.getShape();
    auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
    int64_t inRank = memRefInType.getRank();
    int64_t outRank = tensorOutType.getRank();

    // Get attributes
    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
    std::vector<int64_t> axes;
    if (axisAttrs) {
      for (auto axisAttr : axisAttrs.getValue()) {
        int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
        axis = axis >= 0 ? axis : (inRank + axis);
        assert(axis >= -inRank && axis <= inRank - 1);
        if (std::find(axes.begin(), axes.end(), axis) == axes.end())
          axes.push_back(axis);
      }
    } else {
      for (decltype(inRank) i = 0; i < inRank; ++i) {
        axes.push_back(i);
      }
    }
    // KeepDims
    auto keepdims =
        llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
    bool isKeepdims = (keepdims == 1);

    // Get type information
    auto memRefOutType = convertTensorToMemRef(tensorOutType);
    auto memRefOutShape = memRefOutType.getShape();
    auto elementOutType = memRefOutType.getElementType();
    std::map<int64_t, int64_t> outInDimMap =
        getReductionMapping(memRefInType, axes, isKeepdims);
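    // For instance (assuming getReductionMapping behaves as its uses below
    // suggest): with inRank = 3, axes = [0, 2] and keepdims = false, the only
    // surviving dimension is input dim 1, so outInDimMap = {0 -> 1}; with
    // keepdims = true the reduced output dims are simply absent from the map
    // and are indexed with a constant 0 when storing.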

    // Insert an allocation and deallocation for the result of this operation.
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefOutType)) {
      alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
    } else {
      SmallVector<Value, 2> allocOperands;
      for (decltype(outRank) i = 0; i < outRank; ++i) {
        if (memRefOutShape[i] < 0) {
          auto dim = rewriter.create<DimOp>(loc, operands[0], outInDimMap[i]);
          allocOperands.push_back(dim);
        }
      }
      alloc = rewriter.create<AllocOp>(loc, memRefOutType, allocOperands);
      if (insertDealloc) {
        auto *parentBlock = alloc.getDefiningOp()->getBlock();
        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }
    }

    // There are two Krnl loops:
    // - One to initialize the result memref, and
    // - One to do reduction

    // Define loops to initialize the result.
    std::vector<Value> originalLoopsInit;
    std::vector<Value> optimizedLoopsInit;
    Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit,
        optimizedLoopsInit, outRank);

    // Iteration information
    KrnlIterateOperandPack packInit(rewriter, originalLoopsInit,
        optimizedLoopsInit);
    for (decltype(outRank) i = 0; i < outRank; ++i) {
      addDimensionToPack(rewriter, loc, packInit, alloc, i);
    }
    auto iterateOpInit = rewriter.create<KrnlIterateOp>(loc, packInit);
    Block &iterationBlockInit = iterateOpInit.bodyRegion().front();

    // Perform the insertions into the body of the initialization loop.
    // No optimization
    rewriter.setInsertionPointToEnd(optimizationBlockInit);
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoopsInit);

    // Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlockInit);

    // Handle the operation:
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlockInit.getArguments()) {
      loopIVs.push_back(arg);
    }

    Value identity;
    if (elementOutType.isa<FloatType>()) {
      identity = rewriter.create<ConstantOp>(
          loc, FloatAttr::get(elementOutType,
                   getIdentityValue<float, ONNXReductionOp>()));
    } else if (elementOutType.isa<IntegerType>()) {
      identity = rewriter.create<ConstantOp>(
          loc, IntegerAttr::get(elementOutType,
                   getIdentityValue<int, ONNXReductionOp>()));
    } else {
      emitError(loc, "unsupported element type");
    }
    rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);

    // Define a Krnl loop to do the reduction.
    rewriter.setInsertionPointAfter(iterateOpInit);
    std::vector<Value> originalLoops, optimizedLoops;
    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
        optimizedLoops, inRank);
    // Iteration information
    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
    for (decltype(inRank) i = 0; i < inRank; ++i) {
      addDimensionToPack(rewriter, loc, pack, operands[0], i);
    }
    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // Perform the insertions into the body of the reduction loop.
    // No optimization
    rewriter.setInsertionPointToEnd(optimizationBlock);
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> inLoopIVs, outLoopIVs;
    auto args = iterationBlock.getArguments();
    for (int i = 0; i < args.size(); ++i) {
      inLoopIVs.push_back(args[i]);
    }
    Value zeroIndex = nullptr;
    for (decltype(inRank) i = 0; i < outRank; ++i) {
      if (outInDimMap.find(i) != outInDimMap.end()) {
        outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]);
      } else {
        if (zeroIndex) {
          outLoopIVs.push_back(zeroIndex);
        } else {
          zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
          outLoopIVs.push_back(zeroIndex);
        }
      }
    }

    Value next, accumulated;
    next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
    accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
    accumulated = mapToLowerScalarOp<ONNXReductionOp>(
        op, memRefOutType.getElementType(), {accumulated, next}, rewriter);
    rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);

    rewriter.replaceOp(op, alloc);
    return matchSuccess();
  }
};

void populateLoweringONNXReductionOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>,
                  ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
                  ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
                  ONNXReductionOpLowering<mlir::ONNXReduceSumOp>>(ctx);
}
@ -0,0 +1,205 @@
//===----- softmax.inc - Softmax Op ---------------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Softmax operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//

struct ONNXSoftmaxOpLowering : public ConversionPattern {
  ONNXSoftmaxOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // softmax(x) = let max_x = max(x) in
    //              let exp_x = exp(x - max_x) in
    //              let sum = sum(exp_x) in
    //              exp_x / sum
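    // Note: subtracting max_x before exponentiation is the usual trick for
    // numerical stability; the common factor exp(-max_x) cancels in the final
    // division, so the result is unchanged.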
    auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
    int64_t rank = tensorType.getRank();
    int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
    axis = axis >= 0 ? axis : rank + axis;
    assert(axis >= -rank && axis <= rank - 1);

    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    auto elementType = memRefType.getElementType();

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    operands[0]);

    // Shape of the result
    auto memRefShape = memRefType.getShape();

    // Insert allocations and deallocations for sum and max.
    MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
    Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
    Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
    Value zero =
        rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
    Value negInfinity = rewriter.create<ConstantOp>(
        loc,
        FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));

    // Define loops.
    std::vector<Value> originalLoops;
    std::vector<Value> optimizedLoops;
    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
        optimizedLoops, rank);

    // Coerce the input into a 2-D tensor. `axis` will be the coercing point.
    // This coercing follows the softmax definition in ONNX:
    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
    // Here, we create an outer loop and inner loop for handling the two
    // dimensions. The outer loop is created only when `axis` is not zero.
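    // For example, for an input of shape (2, 3, 4) with axis = 1: the outer
    // loop runs over dimension 0, and the max/sum/softmax inner loops each run
    // over dimensions 1 and 2, i.e. over the 12 coerced "columns" per row.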

    // Define an outer loop with respect to axis.
    std::vector<Value> outerLoops, optimizedOuterLoops;
    outerLoops.reserve(axis);
    optimizedOuterLoops.reserve(axis);
    for (int i = 0; i < axis; ++i) {
      outerLoops.push_back(originalLoops[i]);
      optimizedOuterLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
    for (int i = 0; i < axis; ++i)
      addDimensionToPack(rewriter, loc, outerPack, operands[0], i);

    // Define an inner loop with respect to axis.
    std::vector<Value> innerLoops, optimizedInnerLoops;
    innerLoops.reserve(rank - axis);
    optimizedInnerLoops.reserve(rank - axis);
    for (int i = axis; i < rank; ++i) {
      innerLoops.push_back(originalLoops[i]);
      optimizedInnerLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
    for (int i = axis; i < rank; ++i)
      addDimensionToPack(rewriter, loc, innerPack, operands[0], i);

    KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
    SmallVector<Value, 4> outerLoopIVs;
    if (axis != 0) {
      outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);

      // No optimization
      rewriter.setInsertionPointToEnd(optimizationBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

      // Insert instructions inside the outer loop.
      Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&outerIterationBlock);
      for (auto arg : outerIterationBlock.getArguments())
        outerLoopIVs.push_back(arg);

      // Reset accumulators.
      rewriter.create<StoreOp>(loc, zero, sumOp);
      rewriter.create<StoreOp>(loc, negInfinity, maxOp);

      // Create an inner loop to compute max.
      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute sum.
      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute softmax.
      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
    } else {
      // Reset accumulators.
      rewriter.create<StoreOp>(loc, zero, sumOp);
      rewriter.create<StoreOp>(loc, negInfinity, maxOp);

      // Create an inner loop to compute max.
      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute sum.
      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute softmax.
      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);

      // No optimization
      rewriter.setInsertionPointToEnd(optimizationBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
    }

    // Insert instructions inside the max loop.
    Block &maxIterationBlock = maxIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&maxIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> maxLoopIVs;
    for (auto arg : outerLoopIVs)
      maxLoopIVs.push_back(arg);
    for (auto arg : maxIterationBlock.getArguments())
      maxLoopIVs.push_back(arg);

    // Compute the max value.
    Value max = rewriter.create<LoadOp>(loc, maxOp);
    Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
    auto maxCond =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
    max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
    rewriter.create<StoreOp>(loc, max, maxOp);

    // Get the max.
    rewriter.setInsertionPoint(sumIterateOp);
    max = rewriter.create<LoadOp>(loc, maxOp);

    // Insert instructions inside the sum loop.
    Block &sumIterationBlock = sumIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&sumIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> sumLoopIVs;
    for (auto arg : outerLoopIVs)
      sumLoopIVs.push_back(arg);
    for (auto arg : sumIterationBlock.getArguments())
      sumLoopIVs.push_back(arg);

    // Sum up values.
    Value sum = rewriter.create<LoadOp>(loc, sumOp);
    Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
    Value sub = rewriter.create<SubFOp>(loc, next, max);
    Value exp = rewriter.create<ExpOp>(loc, sub);
    sum = rewriter.create<AddFOp>(loc, sum, exp);
    rewriter.create<StoreOp>(loc, sum, sumOp);
    // Store intermediate values in the result to avoid recomputation.
    rewriter.create<StoreOp>(loc, exp, alloc, sumLoopIVs);

    // Get the sum.
    rewriter.setInsertionPoint(softmaxIterateOp);
    sum = rewriter.create<LoadOp>(loc, sumOp);

    // Insert instructions inside the softmax loop.
    Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&softmaxIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> softmaxLoopIVs;
    for (auto arg : outerLoopIVs)
      softmaxLoopIVs.push_back(arg);
    for (auto arg : softmaxIterationBlock.getArguments())
      softmaxLoopIVs.push_back(arg);

    // Compute softmax.
    Value expLoadedVal = rewriter.create<LoadOp>(loc, alloc, softmaxLoopIVs);
    Value result = rewriter.create<DivFOp>(loc, expLoadedVal, sum);
    rewriter.create<StoreOp>(loc, result, alloc, softmaxLoopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

void populateLoweringONNXSoftmaxOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXSoftmaxOpLowering>(ctx);
}
@ -0,0 +1,282 @@
//===----- conv.inc - Lowering Convolution Op -----------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Convolution Operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//

struct ONNXConvNoBiasOpLowering : public ConversionPattern {
  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();
    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    {operands[0]});

    auto resultShape = memRefType.getShape();
    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
    auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();

    // R = ConvNoBias(D, K)
    //
    // The input/output shapes will look like this:
    //
    // D (NxCxHxW) x K (MxC/groupxKHxKW) -> R (NxMxRHxRW)
    //
    // M is a multiple of the number of groups:
    //   M = group * kernelsPerGroup
    //
    // The loop nest will look as follows:
    //
    // strides = [s1, s2]
    //
    // kernelsPerGroup = M / group;
    // for n = 0 .. N:
    //   for g = 0 .. group:
    //     for m = 0 .. kernelsPerGroup:
    //       kernel = g * kernelsPerGroup + m;
    //       for r1 = 0 .. RH:
    //         for r2 = 0 .. RW:
    //           R[n][kernel][r1][r2] = 0;
    //           for c = 0 .. C/group:
    //             for k1 = 0 .. KH:
    //               for k2 = 0 .. KW:
    //                 R[n][kernel][r1][r2] +=
    //                   D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
    //                   K[kernel][c][k1][k2];
    //
    // Naming:
    //   n, g, m: outer loop nest indices
    //   r1, r2: spatial loop nest indices
    //   c, k1, k2: inner loop nest indices
    //
    // TODO: handle padding.
    //
    // In the general case:
    //
    // D (NxCxD1xD2x...xDdim) x K (MxC/groupxK1xK2x...xKdim)
    //     -> R (NxMxR1xR2x...xRdim)
    //
    // The above loop nest can be adapted by increasing the number
    // of r- and k-index loops, i.e. the r1, r2 and k1, k2 loops.
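    // Worked example (illustrative numbers only): with N = 1, C = 4, group = 2
    // and M = 6 kernels, kernelsPerGroup = 3; kernel IDs 0-2 belong to group 0
    // and read input channels 0-1, while kernel IDs 3-5 belong to group 1 and
    // read input channels 2-3.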

    // Set up outermost loops: n g m r1 r2 ... rdim
    // Skip g if group is 1.

    // Before we start the iteration we need to compute the number of
    // unsplit kernels and fetch the number of groups from the attribute
    // list. Group is always a compilation constant.
    int64_t group = convOp.group().getSExtValue();
    // Compute the number of unsplit kernels. The number of kernels
    // must be a multiple of the number of groups.
    int64_t kernelsPerGroup = floor(kernelShape[0] / group);
    auto kernelsPerGroupValue =
        rewriter.create<ConstantIndexOp>(loc, kernelsPerGroup);
    auto zero = rewriter.create<ConstantOp>(
        loc, FloatAttr::get(memRefType.getElementType(), 0));
    Value subchannels;
    if (kernelShape[1] < 0) {
      subchannels =
          rewriter.create<DimOp>(loc, operands[1], 1).getResult();
    } else {
      subchannels = rewriter.create<ConstantIndexOp>(
          loc, kernelShape[1]);
    }
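    // kernelShape[1] is C/group (the number of input channels each kernel
    // actually sees), so `subchannels` is the per-group channel count used
    // below when computing the data index g * (C / group) + c.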
|
||||||
|
|
||||||
|
// 1. Define outer loops and emit empty optimization block:
|
||||||
|
int64_t nOuterLoops = (group > 1) ? 3 : 2;
|
||||||
|
std::vector<Value> outerLoops;
|
||||||
|
std::vector<Value> optimizedOuterLoops;
|
||||||
|
Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
|
||||||
|
optimizedOuterLoops, nOuterLoops);
|
||||||
|
|
||||||
|
// Prepare iteration arguments over outer loop nest.
|
||||||
|
KrnlIterateOperandPack pack(
|
||||||
|
rewriter, outerLoops, optimizedOuterLoops);
|
||||||
|
// for n = 0 .. N:
|
||||||
|
pack.pushConstantBound(0);
|
||||||
|
if (inputShape[0] < 0)
|
||||||
|
pack.pushOperandBound(
|
||||||
|
rewriter.create<DimOp>(loc, operands[0], 0).getResult());
|
||||||
|
else
|
||||||
|
pack.pushConstantBound(inputShape[0]);
|
||||||
|
// for g = 0 .. N:
|
||||||
|
if (group > 1) {
|
||||||
|
pack.pushConstantBound(0);
|
||||||
|
pack.pushConstantBound(group);
|
||||||
|
}
|
||||||
|
// for m = 0 .. kernelsPerGroup:
|
||||||
|
pack.pushConstantBound(0);
|
||||||
|
pack.pushConstantBound(kernelsPerGroup);
|
||||||
|
// Outer loop iteration.
|
||||||
|
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||||
|
Block &outerIterationBlock = iterateOp.bodyRegion().front();
|
||||||
|
// Emit optimizations for outer loops:
|
||||||
|
rewriter.setInsertionPointToEnd(optimizationBlock);
|
||||||
|
rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
|
||||||
|
rewriter.setInsertionPointToStart(&outerIterationBlock);
|
||||||
|
{
|
||||||
|
// 2. Emit the body of the outer loop nest.
|
||||||
|
|
||||||
|
// 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
|
||||||
|
// If group is not set then the value of the kernel ID is
|
||||||
|
// identical to that of the loop over kernels.
|
||||||
|
Value kernel = outerIterationBlock.getArguments()[1];
|
||||||
|
if (group > 1) {
|
||||||
|
// Middle loop is over groups and third loop is over the
|
||||||
|
// kernel identifiers in the current group.
|
||||||
|
auto kernelsOffset = rewriter.create<MulIOp>(loc,
|
||||||
|
outerIterationBlock.getArguments()[1],
|
||||||
|
kernelsPerGroupValue);
|
||||||
|
kernel = rewriter.create<AddIOp>(loc, kernelsOffset,
|
||||||
|
outerIterationBlock.getArguments()[2]);
|
||||||
|
}

      // 2.2 Define spatial loops
      int64_t nSpatialLoops = resultShape.size() - 2;
      std::vector<Value> spatialLoops;
      std::vector<Value> optimizedSpatialLoops;
      Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
          optimizedSpatialLoops, nSpatialLoops);

      // 2.3 Prepare iteration arguments for spatial loop nest.
      KrnlIterateOperandPack spatialPack(
          rewriter, spatialLoops, optimizedSpatialLoops);
      for (int i = 2; i < resultShape.size(); ++i)
        addDimensionToPack(rewriter, loc, spatialPack, alloc, i);

      // 2.4 Emit loop nest over output spatial dimensions.
      //   for rX = 0 .. RX
      auto spatialIterateOp =
          rewriter.create<KrnlIterateOp>(loc, spatialPack);
      Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
      // 2.5 Emit optimizations for the spatial loops:
      rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
      rewriter.setInsertionPointToStart(&spatialIterationBlock);
      {
        // 3. Emit the body of the spatial loop nest.
        // 3.1 Emit: R[n][kernel][r1][r2] = 0;
        SmallVector<Value, 4> resultIndices;
        // n
        resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
        // kernel
        resultIndices.emplace_back(kernel);
        // rX
        for (auto arg : spatialIterationBlock.getArguments())
          resultIndices.emplace_back(arg);
        // Store initializer value into output location.
        rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);

        // 3.2 Define inner loops.
        int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
        std::vector<Value> innerLoops;
        std::vector<Value> optimizedInnerLoops;
        Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
            optimizedInnerLoops, nInnerLoops);

        // 3.3 Prepare iteration arguments for inner loop nest.
        KrnlIterateOperandPack innerPack(
            rewriter, innerLoops, optimizedInnerLoops);
        // for c = 0 .. C/group
        innerPack.pushConstantBound(0);
        innerPack.pushConstantBound(kernelShape[1]);
        // for kX = 0 .. KX
        for (int i = 2; i < kernelShape.size(); ++i)
          addDimensionToPack(rewriter, loc, innerPack, operands[1], i);

        // 3.4 Emit inner loop nest.
        auto innerIterateOp =
            rewriter.create<KrnlIterateOp>(loc, innerPack);
        Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
        // 3.5 Emit optimizations for the inner loops:
        rewriter.setInsertionPointToEnd(optInnerLoopBlock);
        rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
        rewriter.setInsertionPointToStart(&innerIterationBlock);
        {
          // 4. Emit inner loop body:
          //   R[n][kernel][r1][r2] =
          //       D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
          //       K[kernel][c][k1][k2];

          // 4.1 Prepare indices for accessing the data tensor.
          SmallVector<Value, 4> dataIndices;
          // n
          dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
          // g * (C / group) + c
          Value channelDepth = innerIterationBlock.getArguments()[0];
          if (group > 1)
            channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
                rewriter.create<MulIOp>(loc, subchannels,
                    outerIterationBlock.getArguments()[1]));
          dataIndices.emplace_back(channelDepth);
          // sX * rX + kX
          // Read the strides attribute.
          auto stridesAttribute = convOp.stridesAttr();
          SmallVector<int, 4> strides;
          if (stridesAttribute)
            for (auto stride : stridesAttribute.getValue())
              strides.emplace_back(stride.cast<IntegerAttr>().getInt());
          for (int i = 0; i < kernelShape.size() - 2; ++i) {
            Value spatialIndex = spatialIterationBlock.getArguments()[i];
            // If strides are present then emit the correct access index.
            if (stridesAttribute && strides[i] > 1)
              spatialIndex = rewriter.create<MulIOp>(loc,
                  rewriter.create<ConstantIndexOp>(loc, strides[i]),
                  spatialIterationBlock.getArguments()[i]);
            dataIndices.emplace_back(
                rewriter.create<AddIOp>(loc, spatialIndex,
                    innerIterationBlock.getArguments()[i + 1]));
          }
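          // For example (illustration only): with stride 2, output spatial
          // index r and kernel offset k read input index 2 * r + k.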

          // 4.2 Prepare indices for accessing the kernel tensor.
          SmallVector<Value, 4> kernelIndices;
          // kernel
          kernelIndices.emplace_back(kernel);
          // c
          kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
          // kX
          for (int i = 0; i < kernelShape.size() - 2; ++i)
            kernelIndices.emplace_back(
                innerIterationBlock.getArguments()[i + 1]);

          // 4.3 Compute convolution.
          auto loadData =
              rewriter.create<LoadOp>(loc, operands[0], dataIndices);
          auto loadKernel =
              rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
          auto loadPartialSum =
              rewriter.create<LoadOp>(loc, alloc, resultIndices);
          Value result = rewriter.create<AddFOp>(loc, loadPartialSum,
              rewriter.create<MulFOp>(loc, loadData, loadKernel));
          // 4.4 Store computed value into output location.
          rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
        }
      }
    }
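
    // The loop nest emitted above is roughly equivalent to the following
    // pseudo-code (illustration only, shown for group == 1):
    //   for n, kernel:
    //     for r1, r2:
    //       R[n][kernel][r1][r2] = 0;
    //       for c, k1, k2:
    //         R[n][kernel][r1][r2] +=
    //             D[n][c][s1 * r1 + k1][s2 * r2 + k2] * K[kernel][c][k1][k2];
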
    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

void populateLoweringONNXConvOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXConvNoBiasOpLowering>(ctx);
}

@@ -0,0 +1,26 @@
//===----- identity.inc - Lowering Identity Op ----------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Identity Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//

struct ONNXIdentityOpLowering : public ConversionPattern {
  ONNXIdentityOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, operands[0]);
    return matchSuccess();
  }
};

void populateLoweringONNXIdentityOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXIdentityOpLowering>(ctx);
}

@@ -0,0 +1,151 @@
//===----- reshape.inc - Lowering Reshape Op ------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Reshape Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//

struct ONNXReshapeOpLowering : public ConversionPattern {
  ONNXReshapeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    auto memRefShape = memRefType.getShape();
    Value alloc;

    // Compute size in bytes using the input tensor.
    Value tensorSize = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                 getMemRefEltSizeInBytes(memRefType)));
    for (int i = 0; i < inputShape.size(); ++i) {
      Value dimVal;
      if (inputShape[i] < 0) {
        Value dim = rewriter.create<DimOp>(loc, operands[0], i);
        dimVal = rewriter.create<IndexCastOp>(
            loc, dim, rewriter.getIntegerType(64));
      } else {
        dimVal = rewriter.create<ConstantOp>(
            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                     inputShape[i]));
      }
      tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
    }
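    // For example (illustration only): a 2x3x?xf32 input whose dynamic
    // dimension is 5 at runtime yields tensorSize = 4 * 2 * 3 * 5 = 120 bytes.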

    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType)) {
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    } else {
      // If a dimension is zero, the actual dimension value is taken from the
      // input tensor.
      //
      // If the shape array has a negative dimension (-1), we compute its
      // actual value from the other dimensions. But we don't have enough
      // information about the other dimensions at this point. So, we first
      // scan the shape array and compute the product of all of its entries.
      // If that product is negative, the shape array contains a negative
      // dimension; otherwise it equals the total size computed from the
      // input tensor.
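      //
      // For example (illustration only): for a 2x3x4xf32 input and shape
      // array [4, 0, -1], the 0 is replaced by the input dimension 3, the
      // running product (including the 4-byte element size) is
      // 4 * 4 * 3 * -1 = -48, and after negation the missing dimension is
      // recovered as tensorSize / 48 = 96 / 48 = 2.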
      Value tensorSizeFromShape = rewriter.create<ConstantOp>(
          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                   getMemRefEltSizeInBytes(memRefType)));
      SmallVector<Value, 4> DimInfo;
      for (int i = 0; i < memRefShape.size(); ++i) {
        Value index = rewriter.create<ConstantOp>(
            loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
        // Load the i-th entry of the shape array.
        Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
        // If a dimension is zero, the actual dimension value is taken from the
        // input tensor.
        //
        // If a dimension is negative, it is computed from the other
        // dimensions. But we don't have enough information about the other
        // dimensions at this point, so we leave it as it is (-1) and compute
        // it later.
        if (i < inputShape.size()) {
          Value dimVal;
          auto loadedValType = loadedVal.getType().cast<IntegerType>();
          if (inputShape[i] < 0) {
            Value dim = rewriter.create<DimOp>(loc, operands[0], i);
            dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
          } else {
            dimVal = rewriter.create<ConstantOp>(
                loc, rewriter.getIntegerAttr(loadedValType, inputShape[i]));
          }
          auto zero = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(loadedValType, 0));
          auto isZero =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero);
          loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
        }
        // Check if the loaded index is already the correct width of 64 bits.
        // Convert the value to a 64 bit integer if needed.
        Value int64LoadedVal = loadedVal;
        if (loadedVal.getType().cast<IntegerType>().getWidth() < 64)
          int64LoadedVal = rewriter.create<ZeroExtendIOp>(
              loc, loadedVal, rewriter.getIntegerType(64));
        tensorSizeFromShape =
            rewriter.create<MulIOp>(loc, tensorSizeFromShape, int64LoadedVal);
        // Store intermediate results to use later.
        DimInfo.emplace_back(int64LoadedVal);
      }
      // Negate tensorSizeFromShape since it is negative whenever the shape
      // array has a negative dimension. This is safe because it is only used
      // to compute the actual value of that negative dimension.
      auto zero = rewriter.create<ConstantOp>(
          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      tensorSizeFromShape =
          rewriter.create<SubIOp>(loc, zero, tensorSizeFromShape);

      // Obtain operands for AllocOp.
      SmallVector<Value, 4> allocOperands;
      auto negOne = rewriter.create<ConstantOp>(
          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));

      for (int i = 0; i < memRefShape.size(); ++i) {
        auto dimVal = DimInfo[i];
        auto isNegOne =
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dimVal, negOne);
        // If the dimension is negative, compute its value from the other
        // dimensions.
        auto actualDimVal =
            rewriter.create<SignedDivIOp>(loc, tensorSize, tensorSizeFromShape);
        auto loadedVal =
            rewriter.create<SelectOp>(loc, isNegOne, actualDimVal, dimVal);
        allocOperands.push_back(rewriter.create<IndexCastOp>(
            loc, loadedVal, rewriter.getIndexType()));
      }
      AllocOp allocateMemref =
          rewriter.create<AllocOp>(loc, memRefType, allocOperands);

      // If requested, insert a deallocation for this buffer at the end of
      // the parent block.
      auto *parentBlock = allocateMemref.getOperation()->getBlock();
      if (insertDealloc) {
        auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }

      alloc = allocateMemref;
    }

    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

void populateLoweringONNXReshapeOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXReshapeOpLowering>(ctx);
}

@@ -0,0 +1,99 @@
//===----- transpose.inc - Lowering Transpose Op --------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Transpose Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//

struct ONNXTransposeOpLowering : public ConversionPattern {
  ONNXTransposeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    {operands[0]});

    // Number of loops.
    auto memRefShape = memRefType.getShape();
    int64_t rank = memRefShape.size();

    // Define loops.
    std::vector<Value> originalLoops;
    std::vector<Value> optimizedLoops;
    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
        optimizedLoops, rank);

    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
    // Iterate over the loop nest using the input shape.
    for (int i = 0; i < rank; ++i)
      addDimensionToPack(rewriter, loc, pack, operands[0], i);

    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // Now perform the insertions into the bodies of the just generated
    // operations:

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(optimizationBlock);
    // Return from the KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 2. Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation.

    // Read the perm attribute.
    SmallVector<int, 4> perm;
    auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
    if (permAttribute) {
      for (auto permVal : permAttribute.getValue())
        perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
    } else {
      // TODO: Remove when perm is guaranteed to be present (even for
      // the default case). This means that perm was added by shape
      // inference or another pass to contain the values corresponding
      // to the default behavior of Transpose.
      for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
        perm.emplace_back(i);
    }
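    // For example (illustration only): for a rank-3 input with no perm
    // attribute, perm becomes [2, 1, 0], i.e. the axes are fully reversed.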

    SmallVector<Value, 4> inLoopIVs;
    for (auto arg : iterationBlock.getArguments())
      inLoopIVs.emplace_back(arg);

    SmallVector<Value, 4> outLoopIVs;
    for (int i = 0; i < iterationBlock.getArguments().size(); ++i)
      outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);
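    // For instance (illustration only): with perm = [1, 0], the value loaded
    // at input index (i, j) is stored at output index (j, i).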

    auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
    rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

void populateLoweringONNXTransposeOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXTransposeOpLowering>(ctx);
}

@@ -0,0 +1,86 @@
//===----- unsqueeze.inc - Lowering Unsqueeze Op --------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Unsqueeze Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//

struct ONNXUnsqueezeOpLowering : public ConversionPattern {
  ONNXUnsqueezeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    int outRank = tensorType.getRank();

    // Assume that `axes` has been validated by shape inference, so we just
    // read it here and normalize negative axes.
    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXUnsqueezeOp>(op).axesAttr();
    SmallVector<int, 4> axes;
    for (auto axisAttr : axisAttrs.getValue()) {
      int axis = axisAttr.cast<IntegerAttr>().getInt();
      axis = axis >= 0 ? axis : (outRank + axis);
      axes.emplace_back(axis);
    }
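    // For example (illustration only): with outRank = 4, an axis of -1
    // normalizes to 3.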

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    Value alloc;

    // Compute size in bytes.
    Value tensorSize = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                 getMemRefEltSizeInBytes(memRefType)));

    bool insertDealloc = checkInsertDealloc(op);
    auto memRefShape = memRefType.getShape();
    if (hasAllConstantDimensions(memRefType)) {
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
      for (int i = 0; i < memRefShape.size(); ++i) {
        Value dimVal = rewriter.create<ConstantOp>(
            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                     memRefShape[i]));
        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
      }
    } else {
      // Unknown output dimensions always take their size from the
      // corresponding input dimensions.
      SmallVector<Value, 4> allocOperands;
      for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
        Value dimVal = nullptr;
        if (memRefShape[outIdx] < 0) {
          Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
          dimVal = rewriter.create<IndexCastOp>(
              loc, index, rewriter.getIntegerType(64));
          allocOperands.emplace_back(index);
        } else {
          dimVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                       memRefShape[outIdx]));
        }
        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
        // Only advance the input dimension when the current output dimension
        // is not one of the newly inserted axes.
        if (std::find(axes.begin(), axes.end(), outIdx) == axes.end())
          inIdx++;
      }
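      // For example (illustration only): for an input of shape (?, 5) with
      // axes = [0], the output shape is (1, ?, 5); the dynamic output
      // dimension 1 takes its size from input dimension 0.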
      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
      auto *parentBlock = alloc.getDefiningOp()->getBlock();
      if (insertDealloc) {
        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }
    }
    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
    rewriter.replaceOp(op, alloc);
    return matchSuccess();
  }
};

void populateLoweringONNXUnsqueezeOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXUnsqueezeOpLowering>(ctx);
}