// onnx-mlir/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp

//===---------------- Concat.cpp - Lowering Concat Op -------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Concat Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
using namespace mlir;
struct ONNXConcatOpLowering : public ConversionPattern {
  ONNXConcatOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXConcatOp::getOperationName(), 1, ctx) {}

  /// Lowers ONNXConcatOp to the Krnl dialect: allocates the result buffer,
  /// then emits one Krnl loop nest per input operand that copies the operand
  /// element-wise into the proper slice of the result along the concat axis.
  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    // This pattern only ever matches ONNXConcatOp, so the cast cannot fail;
    // llvm::cast asserts that instead of silently yielding null.
    auto concatOp = llvm::cast<ONNXConcatOp>(op);
    int64_t axis = concatOp.axis().getSExtValue();
    size_t inputNum = operands.size();

    // Allocate (and, when required, deallocate) the result buffer. Dynamic
    // result dimensions are derived from the result operand's shape.
    bool insertDealloc = checkInsertDealloc(op);
    auto resultOperand = concatOp.concat_result();
    auto memRefType = convertToMemRefType(*op->result_type_begin());
    auto resultShape = memRefType.getShape();
    size_t rank = resultShape.size();
    // Cast avoids a signed/unsigned comparison; axis >= 0 is checked first.
    assert(axis >= 0 && static_cast<size_t>(axis) < rank &&
           "Concat axis out of bounds");
    Value alloc;
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, {resultOperand});

    // Emit one full copy loop nest per input. `writeOffset` accumulates the
    // extents of all previous inputs along the concat axis.
    int64_t writeOffset = 0;
    for (size_t i = 0; i < inputNum; ++i) {
      OpBuilder::InsertionGuard insertGuard(rewriter);
      auto currShape = operands[i].getType().cast<MemRefType>().getShape();
      // The offset is folded into a constant of an affine map below, so the
      // extent along the concat axis must be statically known (dynamic dims
      // are encoded as a negative sentinel and would corrupt the offset).
      assert(currShape[axis] >= 0 &&
             "concat axis extent of every input must be static");

      // One loop per dimension, each bounded by the current input's shape.
      BuildKrnlLoop inputLoops(rewriter, loc, rank);
      inputLoops.createDefineOp();
      for (size_t r = 0; r < rank; ++r)
        inputLoops.pushBounds(0, operands[i], r);
      inputLoops.createIterateOp();
      rewriter.setInsertionPointToStart(inputLoops.getIterateBlock());

      // Read indices are the induction variables; write indices are the same
      // except along the concat axis, where the running offset is added.
      SmallVector<Value, 4> readIndices;
      SmallVector<Value, 4> writeIndices;
      for (size_t r = 0; r < rank; ++r) {
        readIndices.emplace_back(inputLoops.getInductionVar(r));
        if (static_cast<int64_t>(r) != axis || writeOffset == 0) {
          // Off-axis dimension, or first input (offset 0): no adjustment.
          writeIndices.emplace_back(inputLoops.getInductionVar(r));
        } else {
          // On-axis for later inputs: write index = induction var + offset,
          // expressed as an affine map so it stays in the affine dialect.
          AffineMap indexWithOffsetMap =
              AffineMap::get(1, 0, rewriter.getAffineDimExpr(0) + writeOffset);
          Value indexWithOffset =
              rewriter.create<AffineApplyOp>(loc, indexWithOffsetMap,
                  ArrayRef<Value>{inputLoops.getInductionVar(r)});
          writeIndices.emplace_back(indexWithOffset);
        }
      }
      // Element-wise copy: load from the input, store into the output slice.
      auto loadData =
          rewriter.create<AffineLoadOp>(loc, operands[i], readIndices);
      rewriter.create<AffineStoreOp>(loc, loadData, alloc, writeIndices);
      // Advance the write window past this input's slice.
      writeOffset += currShape[axis];
    }
    rewriter.replaceOp(op, alloc);
    return success();
  }
};
/// Registers the ONNXConcatOp-to-Krnl lowering pattern with the given
/// pattern list so the conversion driver can apply it.
void populateLoweringONNXConcatOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXConcatOpLowering>(ctx);
}