//===-----------------------Pad.cpp - Lowering Pad Op -------------------===// // // Copyright 2019 The IBM Research Authors. // // ============================================================================= // // This file lowers the ONNX Pad Operator to Krnl dialect. // //===----------------------------------------------------------------------===// #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" using namespace mlir; struct ONNXPadOpLowering : public ConversionPattern { ONNXPadOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXPadOp::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { ONNXPadOp myOp = llvm::dyn_cast(op); ONNXPadOpOperandAdaptor operandAdaptor(operands); auto tensorType = myOp.output().getType(); auto loc = op->getLoc(); // Only constant padding is supported now. auto padMode = myOp.mode(); if (padMode != "constant") return emitError(loc, "unsupported mode for Pad"); DenseElementsAttr constantValAttr = myOp.getAttr("constant_value") .dyn_cast_or_null(); if (!constantValAttr) return emitError(loc, "unsupported value"); DenseElementsAttr padsAttributes = myOp.getAttr("pads").dyn_cast_or_null(); if (!padsAttributes) return emitError(loc, "Pad: unknown pads"); auto memRefType = convertToMemRefType(tensorType); Value alloc; bool insertDealloc = checkInsertDealloc(op); if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else return emitError(loc, "unexpected output has non-Constant shape"); // Number of loops auto memRefShape = memRefType.getShape(); int64_t rank = memRefShape.size(); // get the padding vector into a temporary smallvector SmallVector pads(rank * 2, -1); auto padsIt = padsAttributes.getValues().begin(); for (int i = 0; i < rank * 2; ++i) pads[i] = (*padsIt++).cast().getInt(); // get the padding value auto valueAttr = (*constantValAttr.getValues().begin()); // Iterate over the loop nest using the output shape. BuildKrnlLoop padLoops(rewriter, loc, rank); padLoops.createDefineAndOptimizeOp(); for (int i = 0; i < rank; ++i) padLoops.pushBounds(0, alloc, i); padLoops.createIterateOp(); // Iterate over the loop nest using the input shape. BuildKrnlLoop valueLoops(rewriter, loc, rank); valueLoops.createDefineAndOptimizeOp(); for (int i = 0; i < rank; ++i) valueLoops.pushBounds(0, operandAdaptor.data(), i); valueLoops.createIterateOp(); // Copy the input data into the output. rewriter.setInsertionPointToStart(valueLoops.getIterateBlock()); SmallVector inLoopIVs; for (int i = 0; i < rank; ++i) inLoopIVs.emplace_back(valueLoops.getInductionVar(i)); SmallVector outLoopIVs; for (int i = 0; i < rank; ++i) { // Calculate the index for the load and store. if (pads[i] == 0) { outLoopIVs.emplace_back(valueLoops.getInductionVar(i)); } else { auto outIV = rewriter.create(loc, rewriter.create(loc, pads[i]), valueLoops.getInductionVar(i)); outLoopIVs.emplace_back(outIV); } } auto originValue = rewriter.create(loc, operandAdaptor.data(), inLoopIVs); rewriter.create(loc, originValue, alloc, outLoopIVs); rewriter.setInsertionPointToStart(padLoops.getIterateBlock()); SmallVector outLoopIVs1; for (int i = 0; i < rank; ++i) outLoopIVs1.emplace_back(padLoops.getInductionVar(i)); auto paddingValue = rewriter.create(loc, valueAttr); rewriter.create(loc, paddingValue, alloc, outLoopIVs1); // Replace the original op with the generated code. rewriter.replaceOp(op, alloc); return success(); } }; void populateLoweringONNXPadOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert(ctx); }