//===------PadConstantValuePad.cpp - Lowering PadConstantValuePad Op ------===// // // Copyright 2019 The IBM Research Authors. // // ============================================================================= // // This file lowers the ONNX PadConstantValuePad Operator to Krnl dialect. // //===----------------------------------------------------------------------===// #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" using namespace mlir; struct ONNXPadConstantValuePadOpLowering : public ConversionPattern { ONNXPadConstantValuePadOpLowering(MLIRContext *ctx) : ConversionPattern( mlir::ONNXPadConstantValuePadOp::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto tensorType = (*op->result_type_begin()); ONNXPadConstantValuePadOpOperandAdaptor operandAdaptor(operands); auto loc = op->getLoc(); // Only constant padding is supported now. auto padMode = llvm::dyn_cast(op).mode(); if (padMode != "constant") return emitError(loc, "unsupported mode for PadConstantValuePad"); auto constantValAttr = llvm::dyn_cast(op).constant_valueAttr(); // Insert an allocation and deallocation for the result of this operation. auto memRefType = convertToMemRefType(tensorType); Value alloc; bool insertDealloc = checkInsertDealloc(op); if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else return emitError(loc, "unexpected output has non-Constant shape"); // Number of loops auto memRefShape = memRefType.getShape(); int64_t rank = memRefShape.size(); // Iterate over the loop nest using the output shape. BuildKrnlLoop padLoops(rewriter, loc, rank); padLoops.createDefineAndOptimizeOp(); for (int i = 0; i < rank; ++i) padLoops.pushBounds(0, alloc, i); padLoops.createIterateOp(); // Iterate over the loop nest using the input shape. BuildKrnlLoop valueLoops(rewriter, loc, rank); valueLoops.createDefineAndOptimizeOp(); for (int i = 0; i < rank; ++i) valueLoops.pushBounds(0, operandAdaptor.data(), i); valueLoops.createIterateOp(); // Copy the input data into the output. rewriter.setInsertionPointToStart(valueLoops.getIterateBlock()); SmallVector inLoopIVs; for (int i = 0; i < rank; ++i) inLoopIVs.emplace_back(valueLoops.getInductionVar(i)); auto pads = llvm::dyn_cast(op).pads(); SmallVector pad_begin; for (int i = 0; i < pads.size() / 2; ++i) { pad_begin.emplace_back(pads.getValue()[i].cast().getInt()); } SmallVector outLoopIVs; for (int i = 0; i < rank; ++i) { // Calculate the index for the load and store. if (pad_begin[i] == 0) { outLoopIVs.emplace_back(valueLoops.getInductionVar(i)); } else { auto outIV = rewriter.create(loc, rewriter.create(loc, pad_begin[i]), valueLoops.getInductionVar(i)); outLoopIVs.emplace_back(outIV); } } auto inVal = rewriter.create(loc, operandAdaptor.data(), inLoopIVs); rewriter.create(loc, inVal, alloc, outLoopIVs); rewriter.setInsertionPointToStart(padLoops.getIterateBlock()); SmallVector outLoopIVs1; for (int i = 0; i < rank; ++i) outLoopIVs1.emplace_back(padLoops.getInductionVar(i)); auto inVal1 = rewriter.create(loc, constantValAttr); rewriter.create(loc, inVal1, alloc, outLoopIVs1); // Replace the original op with the generated code. rewriter.replaceOp(op, alloc); return success(); } }; void populateLoweringONNXPadConstantValuePadOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert(ctx); }