diff --git a/src/Transform/ONNX/ONNXRewrite.cpp b/src/Transform/ONNX/ONNXRewrite.cpp
index d7619f0..afb4030 100644
--- a/src/Transform/ONNX/ONNXRewrite.cpp
+++ b/src/Transform/ONNX/ONNXRewrite.cpp
@@ -70,103 +70,6 @@ ArrayAttr insertZerosForNonPaddedDims(
 /// Include the patterns defined in the Declarative Rewrite framework.
 #include "src/Transform/ONNX/ONNXRewrite.inc"
 
-//===----------------------------------------------------------------------===//
-// Rewrite:
-// %0 = onnx.Conv(%D : tensor, %K)
-//     {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
-//     tensor
-//
-// as:
-// %0 = onnx.PadConstantValuePadOp(%D)
-//     {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
-//     tensor
-// %1 = onnx.Conv(%0 : tensor, %K) {pads = [0, ..., 0]} ->
-//     tensor
-//===----------------------------------------------------------------------===//
-struct SplitConvOpPattern : public RewritePattern {
-  SplitConvOpPattern(MLIRContext *context)
-      : RewritePattern(ONNXConvOp::getOperationName(),
-            {ONNXPadConstantValuePadOp::getOperationName(),
-                ONNXConvOp::getOperationName()},
-            1, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-      PatternRewriter &rewriter) const override {
-    auto loc = op->getLoc();
-
-    // If convolution does not use padding then no rewrite is required.
-    ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
-    auto padsAttribute = convOp.padsAttr();
-    if (!padsAttribute)
-      return matchFailure();
-
-    // If auto_pad is VALID then no padding happens and no rewrite is required.
-    auto autoPad = convOp.auto_pad();
-    if (autoPad == "VALID")
-      return matchFailure();
-
-    auto data = op->getOperands()[0];
-    auto inputShape = data.getType().cast<RankedTensorType>().getShape();
-
-    // Dimensionality of the input:
-    //              inputRank
-    //      |----------------------|
-    // D : (N x C x D1 x D2 x ... DK)
-    //              |______________|
-    //                  inputDims
-    //
-    int64_t inputRank = inputShape.size();
-    int64_t inputDims = inputRank - 2;
-
-    // If all pads values are equal to zero then no rewrite is required.
-    bool allZeros = true;
-    for (auto padsValue : padsAttribute.getValue()) {
-      if (padsValue.cast<IntegerAttr>().getInt() > 0) {
-        allZeros = false;
-        break;
-      }
-    }
-
-    if (allZeros)
-      return matchFailure();
-
-    // Create padding vector for the explicit padding op attribute.
-    SmallVector<int64_t, 4> pads(2 * inputRank, 0);
-    SmallVector<int64_t, 4> outPaddedShape(inputRank, 0);
-    outPaddedShape[0] = inputShape[0];
-    outPaddedShape[1] = inputShape[1];
-    for (int i = 0; i < inputDims; ++i) {
-      int64_t beginPad =
-          padsAttribute.getValue()[i].cast<IntegerAttr>().getInt();
-      int64_t endPad =
-          padsAttribute.getValue()[inputDims + i].cast<IntegerAttr>().getInt();
-      pads[i + 2] = beginPad;
-      pads[inputRank + i + 2] = endPad;
-      outPaddedShape[i + 2] += beginPad + inputShape[i + 2] + endPad;
-    }
-
-    // Create padding operation.
-    auto inputElemType =
-        data.getType().cast<RankedTensorType>().getElementType();
-    ONNXPadConstantValuePadOp paddingOp =
-        rewriter.create<ONNXPadConstantValuePadOp>(
-            loc, RankedTensorType::get(outPaddedShape, inputElemType), data,
-            rewriter.getI64ArrayAttr(pads), FloatAttr::get(inputElemType, 0),
-            StringAttr::get("constant", loc->getContext()));
-
-    SmallVector<int64_t, 4> newConvPads(2 * inputDims, 0);
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
-    ONNXConvOp newConvOp = rewriter.create<ONNXConvOp>(
-        loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1],
-        convOp.getOperands()[2],
-        convOp.auto_padAttr(), convOp.dilationsAttr(),
-        convOp.groupAttr(), convOp.kernel_shapeAttr(),
-        rewriter.getI64ArrayAttr(newConvPads),
-        convOp.stridesAttr());
-
-    rewriter.replaceOp(op, newConvOp.getResult());
-    return matchSuccess();
-  };
-};
 } // end anonymous namespace
 
 /// on the ONNXMaxPoolSingleOutOp.
@@ -177,5 +80,5 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
 /// on the ONNXConvOp.
 void ONNXConvOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<SplitConvOpPattern>(context);
+  results.insert<ConvOpPaddingPattern>(context);
 }
diff --git a/src/Transform/ONNX/ONNXRewrite.td b/src/Transform/ONNX/ONNXRewrite.td
index f0684f9..8e0a559 100644
--- a/src/Transform/ONNX/ONNXRewrite.td
+++ b/src/Transform/ONNX/ONNXRewrite.td
@@ -62,6 +62,10 @@ class insertZerosForNonPaddedDims<int extensionLength>:
 def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">,
                                       "has non-zero elements">;
 
+// Check that a StrAttr does not contain a specific value.
+class IsNotStringAttrOfValue<string val>:
+  Constraint<CPred<"$_self.cast<StringAttr>().getValue() != \"" # val # "\"">>;
+
 //===----------------------------------------------------------------------===//
 // Rewrite:
 // %0 = onnx.MaxPoolSingleOutOp(%D : tensor)
@@ -90,7 +94,38 @@ def MaxPoolSingleOutOpPaddingPattern: Pat<
      $auto_pad, $ceil_mode, $dilation, $kernel_shape,
      (createArrayAttrOfZerosFrom $pads),
      $storage_order, $strides),
-  [(HasNonZeroInArrayAttr:$pads)]
+  [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
 >;
 
+//===----------------------------------------------------------------------===//
+// Rewrite:
+// %0 = onnx.ConvOp(%D : tensor, %K)
+//     {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
+//     tensor
+//
+// as:
+// %0 = onnx.PadConstantValuePadOp(%D)
+//     {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
+//     tensor
+// %1 = onnx.Conv(%0 : tensor, %K) {pads = [0, ..., 0]} ->
+//     tensor
+//===----------------------------------------------------------------------===//
+
+def ConvOpPaddingPattern: Pat<
+  (ONNXConvOp:$res
+     $x,
+     $w, $b, $auto_pad, $dilation, $group, $kernel_shape,
+     $pads,
+     $strides),
+  (ONNXConvOp
+     (ONNXPadConstantValuePadOp $x,
+        (insertZerosForNonPaddedDims<2> $pads),
+        (FloatAttrOfValue<0> $res),
+        (StringAttrOfValue<"constant">)),
+     $w, $b, $auto_pad, $dilation, $group, $kernel_shape,
+     (createArrayAttrOfZerosFrom $pads),
+     $strides),
+  [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
+>;
+
 #endif // ONNX_REWRITE
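Note (not part of the patch): a sketch of what ConvOpPaddingPattern produces for a concrete 2-D convolution, written in the same pseudo-IR notation as the comments above. The tensor shapes, SSA names, and the PadConstantValuePad attribute names (constant_value, mode) are illustrative assumptions, not taken from the diff.

// Before: the convolution carries explicit pads (auto_pad != "VALID").
%Y = onnx.Conv(%D : tensor<1x3x32x32xf32>, %K : tensor<8x3x3x3xf32>)
         {auto_pad = "NOTSET", pads = [1, 1, 1, 1]} -> tensor<1x8x32x32xf32>

// After: the padding is made explicit. insertZerosForNonPaddedDims<2> prepends
// zeros for the two non-spatial dims (N, C) to both halves of the pads array,
// and the convolution itself runs with all-zero pads.
%P = onnx.PadConstantValuePad(%D)
         {pads = [0, 0, 1, 1, 0, 0, 1, 1], constant_value = 0.0, mode = "constant"}
         -> tensor<1x3x34x34xf32>
%Y = onnx.Conv(%P : tensor<1x3x34x34xf32>, %K : tensor<8x3x3x3xf32>)
         {auto_pad = "NOTSET", pads = [0, 0, 0, 0]} -> tensor<1x8x32x32xf32>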