Replace SplitConvOpPattern by a declarative rewriting rule (#46)
parent bab2241b20
commit 867406191f
@@ -70,103 +70,6 @@ ArrayAttr insertZerosForNonPaddedDims(
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXRewrite.inc"

//===----------------------------------------------------------------------===//
// Rewrite:
// %0 = onnx.Conv(%D : tensor<DShape>, %K)
//     {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
//         tensor<OutShape>
//
// as:
// %0 = onnx.PadConstantValuePadOp(%D)
//     {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
//     tensor<DPaddedShape>
// %1 = onnx.Conv(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
//     tensor<OutShape>
//===----------------------------------------------------------------------===//
struct SplitConvOpPattern : public RewritePattern {
  SplitConvOpPattern(MLIRContext *context)
      : RewritePattern(ONNXConvOp::getOperationName(),
                       {ONNXPadConstantValuePadOp::getOperationName(),
                        ONNXConvOp::getOperationName()},
                       1, context) {}

  PatternMatchResult matchAndRewrite(Operation *op,
      PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    // If convolution does not use padding then no rewrite is required.
    ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
    auto padsAttribute = convOp.padsAttr();
    if (!padsAttribute)
      return matchFailure();

    // If auto_pad is VALID then no padding happens and no rewrite is required.
    auto autoPad = convOp.auto_pad();
    if (autoPad == "VALID")
      return matchFailure();

    auto data = op->getOperands()[0];
    auto inputShape = data.getType().cast<TensorType>().getShape();

    // Dimensionality of the input:
    //              inputRank
    //      |----------------------|
    // D : (N x C x D1 x D2 x ... DK)
    //              |______________|
    //                  inputDims
    //
    int64_t inputRank = inputShape.size();
    int64_t inputDims = inputRank - 2;

    // If all pads values are equal to zero then no rewrite is required.
    bool allZeros = true;
    for (auto padsValue : padsAttribute.getValue()) {
      if (padsValue.cast<IntegerAttr>().getInt() > 0) {
        allZeros = false;
        break;
      }
    }

    if (allZeros)
      return matchFailure();

    // Create padding vector for the explicit padding op attribute.
    SmallVector<int64_t, 4> pads(2 * inputRank, 0);
    SmallVector<int64_t, 4> outPaddedShape(inputRank, 0);
    outPaddedShape[0] = inputShape[0];
    outPaddedShape[1] = inputShape[1];
    for (int i = 0; i < inputDims; ++i) {
      int64_t beginPad =
          padsAttribute.getValue()[i].cast<IntegerAttr>().getInt();
      int64_t endPad =
          padsAttribute.getValue()[inputDims + i].cast<IntegerAttr>().getInt();
      pads[i + 2] = beginPad;
      pads[inputRank + i + 2] = endPad;
      outPaddedShape[i + 2] += beginPad + inputShape[i + 2] + endPad;
    }

    // Create padding operation.
    auto inputElemType = data.getType().cast<TensorType>().getElementType();
    ONNXPadConstantValuePadOp paddingOp =
        rewriter.create<ONNXPadConstantValuePadOp>(
            loc, RankedTensorType::get(outPaddedShape, inputElemType), data,
            rewriter.getI64ArrayAttr(pads), FloatAttr::get(inputElemType, 0),
            StringAttr::get("constant", loc->getContext()));

    SmallVector<int64_t, 4> newConvPads(2 * inputDims, 0);
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    ONNXConvOp newConvOp = rewriter.create<ONNXConvOp>(
            loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1],
            convOp.getOperands()[2],
            convOp.auto_padAttr(), convOp.dilationsAttr(),
            convOp.groupAttr(), convOp.kernel_shapeAttr(),
            rewriter.getI64ArrayAttr(newConvPads),
            convOp.stridesAttr());

    rewriter.replaceOp(op, newConvOp.getResult());
    return matchSuccess();
  };
};
} // end anonymous namespace

/// on the ONNXMaxPoolSingleOutOp.
@@ -177,5 +80,5 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
/// on the ONNXConvOp.
void ONNXConvOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SplitConvOpPattern>(context);
  results.insert<ConvOpPaddingPattern>(context);
}
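The declarative rules below lean on a handful of native C++ helpers that live next to this hunk (the hunk header above mentions insertZerosForNonPaddedDims, and the patterns call hasNonZeroInArrayAttr and createArrayAttrOfZerosFrom). The following is only a sketch of what such helpers could look like; the signatures and bodies are assumptions for illustration, not the upstream implementation.

#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Returns true if any element of an I64 ArrayAttr is non-zero; this is what
// the HasNonZeroInArrayAttr constraint is assumed to check.
bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
  for (Attribute attr : attrs.getValue())
    if (attr.cast<IntegerAttr>().getInt() != 0)
      return true;
  return false;
}

// Builds an I64 ArrayAttr of zeros with the same length as the input; this is
// what the (createArrayAttrOfZerosFrom $pads) call is assumed to produce.
ArrayAttr createArrayAttrOfZerosFrom(PatternRewriter &rewriter, ArrayAttr attrs) {
  SmallVector<int64_t, 4> zeros(attrs.getValue().size(), 0);
  return rewriter.getI64ArrayAttr(zeros);
}

// Extends a pads attribute [b0, ..., bK, e0, ..., eK] with extensionLength
// zeros in front of each half, e.g. for extensionLength == 2:
//   [b0, ..., bK, e0, ..., eK] -> [0, 0, b0, ..., bK, 0, 0, e0, ..., eK]
// matching the (insertZerosForNonPaddedDims<2> $pads) call in
// ConvOpPaddingPattern below.
ArrayAttr insertZerosForNonPaddedDims(
    PatternRewriter &rewriter, ArrayAttr pads, int extensionLength) {
  int64_t numDims = pads.getValue().size() / 2;
  SmallVector<int64_t, 8> newPads(2 * (numDims + extensionLength), 0);
  for (int64_t i = 0; i < numDims; ++i) {
    newPads[extensionLength + i] =
        pads.getValue()[i].cast<IntegerAttr>().getInt();
    newPads[2 * extensionLength + numDims + i] =
        pads.getValue()[numDims + i].cast<IntegerAttr>().getInt();
  }
  return rewriter.getI64ArrayAttr(newPads);
}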
@@ -62,6 +62,10 @@ class insertZerosForNonPaddedDims<int extensionLength>:
def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">,
                                       "has non-zero elements">;

// Check that a StrAttr does not contain a specific value.
class IsNotStringAttrOfValue<string val>:
  Constraint<CPred<"$0.cast<StringAttr>().getValue() != \"" # val # "\"">>;
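// For example, IsNotStringAttrOfValue<"VALID"> expands to the predicate
// $0.cast<StringAttr>().getValue() != "VALID", so a pattern guarded by it
// does not fire when auto_pad is VALID.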

//===----------------------------------------------------------------------===//
// Rewrite:
// %0 = onnx.MaxPoolSingleOutOp(%D : tensor<DShape>)
@@ -90,7 +94,38 @@ def MaxPoolSingleOutOpPaddingPattern: Pat<
     $auto_pad, $ceil_mode, $dilation, $kernel_shape,
     (createArrayAttrOfZerosFrom $pads),
     $storage_order, $strides),
  [(HasNonZeroInArrayAttr:$pads)]
  [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
>;

//===----------------------------------------------------------------------===//
// Rewrite:
// %0 = onnx.ConvOp(%D : tensor<DShape>, %K)
//     {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
//         tensor<OutShape>
//
// as:
// %0 = onnx.PadConstantValuePadOp(%D)
//     {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
//     tensor<DPaddedShape>
// %1 = onnx.Conv(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
//     tensor<OutShape>
//===----------------------------------------------------------------------===//

def ConvOpPaddingPattern: Pat<
  (ONNXConvOp:$res
     $x,
     $w, $b, $auto_pad, $dilation, $group, $kernel_shape,
     $pads,
     $strides),
  (ONNXConvOp
     (ONNXPadConstantValuePadOp $x,
        (insertZerosForNonPaddedDims<2> $pads),
        (FloatAttrOfValue<0> $res),
        (StringAttrOfValue<"constant">)),
     $w, $b, $auto_pad, $dilation, $group, $kernel_shape,
     (createArrayAttrOfZerosFrom $pads),
     $strides),
  [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
>;
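// As a concrete illustration (assuming a 2-D convolution over an
// N x C x H x W input): pads = [1, 2, 3, 4] on the original ONNXConvOp becomes
// pads = [0, 0, 1, 2, 0, 0, 3, 4] on the ONNXPadConstantValuePadOp and
// pads = [0, 0, 0, 0] on the rewritten ONNXConvOp.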

#endif // ONNX_REWRITE
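Both ConvOpPaddingPattern and MaxPoolSingleOutOpPaddingPattern are registered through getCanonicalizationPatterns in the C++ hunk above, so they are picked up by MLIR's standard canonicalizer. A minimal sketch of a pipeline that would exercise them, assuming the usual MLIR pass-manager API of this vintage (ModuleOp, createCanonicalizerPass):

#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Runs the canonicalizer, and with it the two padding patterns, over a module.
mlir::LogicalResult canonicalize(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createCanonicalizerPass());
  return pm.run(module);
}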