//===---------------- Concat.cpp - Lowering Concat Op -------------------===// // // Copyright 2019 The IBM Research Authors. // // ============================================================================= // // This file lowers the ONNX Concat Operator to Krnl dialect. // //===----------------------------------------------------------------------===// #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" using namespace mlir; struct ONNXConcatOpLowering : public ConversionPattern { ONNXConcatOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXConcatOp::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Gather info. auto loc = op->getLoc(); Value alloc; bool insertDealloc = checkInsertDealloc(op); ONNXConcatOp concatOp = llvm::dyn_cast(op); auto axis = concatOp.axis().getSExtValue(); int inputNum = operands.size(); // Alloc and dealloc. auto resultOperand = concatOp.concat_result(); auto memRefType = convertToMemRefType(*op->result_type_begin()); auto resultShape = memRefType.getShape(); auto rank = resultShape.size(); assert((axis >= 0 && axis < rank) && "Concat axis out of bounds"); if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else alloc = insertAllocAndDealloc( memRefType, loc, rewriter, insertDealloc, {resultOperand}); // Creates loops, one for each input. int writeOffset = 0; for (int i = 0; i < inputNum; ++i) { OpBuilder::InsertionGuard insertGuard(rewriter); // Operand info. auto currShape = operands[i].getType().cast().getShape(); // Create loop. BuildKrnlLoop inputLoops(rewriter, loc, rank); inputLoops.createDefineOp(); for (int r = 0; r < rank; ++r) inputLoops.pushBounds(0, operands[i], r); inputLoops.createIterateOp(); rewriter.setInsertionPointToStart(inputLoops.getIterateBlock()); // Indices for the read and write. SmallVector readIndices; SmallVector writeIndices; for (int r = 0; r < rank; ++r) { readIndices.emplace_back(inputLoops.getInductionVar(r)); if (r != axis || writeOffset == 0) { writeIndices.emplace_back(inputLoops.getInductionVar(r)); } else { AffineMap indexWithOffsetMap = AffineMap::get(1, 0, rewriter.getAffineDimExpr(0) + writeOffset); Value indexWithOffset = rewriter.create(loc, indexWithOffsetMap, ArrayRef{inputLoops.getInductionVar(r)}); writeIndices.emplace_back(indexWithOffset); } } // Insert copy. auto loadData = rewriter.create(loc, operands[i], readIndices); rewriter.create(loc, loadData, alloc, writeIndices); // Increment offset writeOffset += currShape[axis]; } rewriter.replaceOp(op, alloc); return success(); } }; void populateLoweringONNXConcatOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert(ctx); }