Replace SplitConvOpPattern by a declarative rewriting rule (#46)

This commit is contained in:
Tung D. Le 2020-03-30 15:23:14 +09:00 committed by GitHub
parent bab2241b20
commit 867406191f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 99 deletions

View File

@ -70,103 +70,6 @@ ArrayAttr insertZerosForNonPaddedDims(
/// Include the patterns defined in the Declarative Rewrite framework. /// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXRewrite.inc" #include "src/Transform/ONNX/ONNXRewrite.inc"
//===----------------------------------------------------------------------===//
// Rewrite:
// %0 = onnx.Conv(%D : tensor<DShape>, %K)
// {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
// tensor<OutShape>
//
// as:
// %0 = onnx.PadConstantValuePasOp(%D)
// {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
// tensor<DPaddedShape>
// %1 = onnx.Conv(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
// tensor<OutShape>
//===----------------------------------------------------------------------===//
struct SplitConvOpPattern : public RewritePattern {
SplitConvOpPattern(MLIRContext *context)
: RewritePattern(ONNXConvOp::getOperationName(),
{ONNXPadConstantValuePadOp::getOperationName(),
ONNXConvOp::getOperationName()},
1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
// If convolution does not use padding then no rewrite is required.
ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
auto padsAttribute = convOp.padsAttr();
if (!padsAttribute)
return matchFailure();
// If auto_pad is VALID then no padding happens and no rewrite isrequired.
auto autoPad = convOp.auto_pad();
if (autoPad == "VALID")
return matchFailure();
auto data = op->getOperands()[0];
auto inputShape = data.getType().cast<TensorType>().getShape();
// Dimensionality of the input:
// inputRank
// |----------------------|
// D : (N x C x D1 x D2 x ... DK)
// |______________|
// inputDims
//
int64_t inputRank = inputShape.size();
int64_t inputDims = inputRank - 2;
// If all pads values are equal to zero then no rewrite is required.
bool allZeros = true;
for (auto padsValue : padsAttribute.getValue()) {
if (padsValue.cast<IntegerAttr>().getInt() > 0) {
allZeros = false;
break;
}
}
if (allZeros)
return matchFailure();
// Create padding vector for the explicit padding op attribute.
SmallVector<int64_t, 4> pads(2 * inputRank, 0);
SmallVector<int64_t, 4> outPaddedShape(inputRank, 0);
outPaddedShape[0] = inputShape[0];
outPaddedShape[1] = inputShape[1];
for (int i = 0; i < inputDims; ++i) {
int64_t beginPad =
padsAttribute.getValue()[i].cast<IntegerAttr>().getInt();
int64_t endPad =
padsAttribute.getValue()[inputDims + i].cast<IntegerAttr>().getInt();
pads[i + 2] = beginPad;
pads[inputRank + i + 2] = endPad;
outPaddedShape[i + 2] += beginPad + inputShape[i + 2] + endPad;
}
// Create padding operation.
auto inputElemType = data.getType().cast<TensorType>().getElementType();
ONNXPadConstantValuePadOp paddingOp =
rewriter.create<ONNXPadConstantValuePadOp>(
loc, RankedTensorType::get(outPaddedShape, inputElemType), data,
rewriter.getI64ArrayAttr(pads), FloatAttr::get(inputElemType, 0),
StringAttr::get("constant", loc->getContext()));
SmallVector<int64_t, 4> newConvPads(2 * inputDims, 0);
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
ONNXConvOp newConvOp = rewriter.create<ONNXConvOp>(
loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1],
convOp.getOperands()[2],
convOp.auto_padAttr(), convOp.dilationsAttr(),
convOp.groupAttr(), convOp.kernel_shapeAttr(),
rewriter.getI64ArrayAttr(newConvPads),
convOp.stridesAttr());
rewriter.replaceOp(op, newConvOp.getResult());
return matchSuccess();
};
};
} // end anonymous namespace } // end anonymous namespace
/// on the ONNXMaxPoolSingleOutOp. /// on the ONNXMaxPoolSingleOutOp.
@ -177,5 +80,5 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
/// on the ONNXConvOp. /// on the ONNXConvOp.
void ONNXConvOp::getCanonicalizationPatterns( void ONNXConvOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SplitConvOpPattern>(context); results.insert<ConvOpPaddingPattern>(context);
} }

View File

@ -62,6 +62,10 @@ class insertZerosForNonPaddedDims<int extensionLength>:
def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">, def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">,
"has non-zero elements">; "has non-zero elements">;
// Check that a StrAttr does not contain a specific value.
class IsNotStringAttrOfValue<string val>:
Constraint<CPred<"$0.cast<StringAttr>().getValue() != \"" # val # "\"">>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Rewrite: // Rewrite:
// %0 = onnx.MaxPoolSingleOutOp(%D : tensor<DShape>) // %0 = onnx.MaxPoolSingleOutOp(%D : tensor<DShape>)
@ -90,7 +94,38 @@ def MaxPoolSingleOutOpPaddingPattern: Pat<
$auto_pad, $ceil_mode, $dilation, $kernel_shape, $auto_pad, $ceil_mode, $dilation, $kernel_shape,
(createArrayAttrOfZerosFrom $pads), (createArrayAttrOfZerosFrom $pads),
$storage_order, $strides), $storage_order, $strides),
[(HasNonZeroInArrayAttr:$pads)] [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
>;
//===----------------------------------------------------------------------===//
// Rewrite:
// %0 = onnx.ConvOp(%D : tensor<DShape>, %K)
// {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
// tensor<OutShape>
//
// as:
// %0 = onnx.PadConstantValuePadOp(%D)
// {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
// tensor<DPaddedShape>
// %1 = onnx.Conv(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
// tensor<OutShape>
//===----------------------------------------------------------------------===//
def ConvOpPaddingPattern: Pat<
(ONNXConvOp:$res
$x,
$w, $b, $auto_pad, $dilation, $group, $kernel_shape,
$pads,
$strides),
(ONNXConvOp
(ONNXPadConstantValuePadOp $x,
(insertZerosForNonPaddedDims<2> $pads),
(FloatAttrOfValue<0> $res),
(StringAttrOfValue<"constant">)),
$w, $b, $auto_pad, $dilation, $group, $kernel_shape,
(createArrayAttrOfZerosFrom $pads),
$strides),
[(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
>; >;
#endif // ONNX_REWRITE #endif // ONNX_REWRITE