Replace SplitConvOpPattern by a declarative rewriting rule (#46)
This commit is contained in:
parent
bab2241b20
commit
867406191f
|
@ -70,103 +70,6 @@ ArrayAttr insertZerosForNonPaddedDims(
|
||||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||||
#include "src/Transform/ONNX/ONNXRewrite.inc"
|
#include "src/Transform/ONNX/ONNXRewrite.inc"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Rewrite:
|
|
||||||
// %0 = onnx.Conv(%D : tensor<DShape>, %K)
|
|
||||||
// {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
|
|
||||||
// tensor<OutShape>
|
|
||||||
//
|
|
||||||
// as:
|
|
||||||
// %0 = onnx.PadConstantValuePasOp(%D)
|
|
||||||
// {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
|
|
||||||
// tensor<DPaddedShape>
|
|
||||||
// %1 = onnx.Conv(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
|
|
||||||
// tensor<OutShape>
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
struct SplitConvOpPattern : public RewritePattern {
|
|
||||||
SplitConvOpPattern(MLIRContext *context)
|
|
||||||
: RewritePattern(ONNXConvOp::getOperationName(),
|
|
||||||
{ONNXPadConstantValuePadOp::getOperationName(),
|
|
||||||
ONNXConvOp::getOperationName()},
|
|
||||||
1, context) {}
|
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Operation *op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto loc = op->getLoc();
|
|
||||||
|
|
||||||
// If convolution does not use padding then no rewrite is required.
|
|
||||||
ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
|
|
||||||
auto padsAttribute = convOp.padsAttr();
|
|
||||||
if (!padsAttribute)
|
|
||||||
return matchFailure();
|
|
||||||
|
|
||||||
// If auto_pad is VALID then no padding happens and no rewrite isrequired.
|
|
||||||
auto autoPad = convOp.auto_pad();
|
|
||||||
if (autoPad == "VALID")
|
|
||||||
return matchFailure();
|
|
||||||
|
|
||||||
auto data = op->getOperands()[0];
|
|
||||||
auto inputShape = data.getType().cast<TensorType>().getShape();
|
|
||||||
|
|
||||||
// Dimensionality of the input:
|
|
||||||
// inputRank
|
|
||||||
// |----------------------|
|
|
||||||
// D : (N x C x D1 x D2 x ... DK)
|
|
||||||
// |______________|
|
|
||||||
// inputDims
|
|
||||||
//
|
|
||||||
int64_t inputRank = inputShape.size();
|
|
||||||
int64_t inputDims = inputRank - 2;
|
|
||||||
|
|
||||||
// If all pads values are equal to zero then no rewrite is required.
|
|
||||||
bool allZeros = true;
|
|
||||||
for (auto padsValue : padsAttribute.getValue()) {
|
|
||||||
if (padsValue.cast<IntegerAttr>().getInt() > 0) {
|
|
||||||
allZeros = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (allZeros)
|
|
||||||
return matchFailure();
|
|
||||||
|
|
||||||
// Create padding vector for the explicit padding op attribute.
|
|
||||||
SmallVector<int64_t, 4> pads(2 * inputRank, 0);
|
|
||||||
SmallVector<int64_t, 4> outPaddedShape(inputRank, 0);
|
|
||||||
outPaddedShape[0] = inputShape[0];
|
|
||||||
outPaddedShape[1] = inputShape[1];
|
|
||||||
for (int i = 0; i < inputDims; ++i) {
|
|
||||||
int64_t beginPad =
|
|
||||||
padsAttribute.getValue()[i].cast<IntegerAttr>().getInt();
|
|
||||||
int64_t endPad =
|
|
||||||
padsAttribute.getValue()[inputDims + i].cast<IntegerAttr>().getInt();
|
|
||||||
pads[i + 2] = beginPad;
|
|
||||||
pads[inputRank + i + 2] = endPad;
|
|
||||||
outPaddedShape[i + 2] += beginPad + inputShape[i + 2] + endPad;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create padding operation.
|
|
||||||
auto inputElemType = data.getType().cast<TensorType>().getElementType();
|
|
||||||
ONNXPadConstantValuePadOp paddingOp =
|
|
||||||
rewriter.create<ONNXPadConstantValuePadOp>(
|
|
||||||
loc, RankedTensorType::get(outPaddedShape, inputElemType), data,
|
|
||||||
rewriter.getI64ArrayAttr(pads), FloatAttr::get(inputElemType, 0),
|
|
||||||
StringAttr::get("constant", loc->getContext()));
|
|
||||||
|
|
||||||
SmallVector<int64_t, 4> newConvPads(2 * inputDims, 0);
|
|
||||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
|
||||||
ONNXConvOp newConvOp = rewriter.create<ONNXConvOp>(
|
|
||||||
loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1],
|
|
||||||
convOp.getOperands()[2],
|
|
||||||
convOp.auto_padAttr(), convOp.dilationsAttr(),
|
|
||||||
convOp.groupAttr(), convOp.kernel_shapeAttr(),
|
|
||||||
rewriter.getI64ArrayAttr(newConvPads),
|
|
||||||
convOp.stridesAttr());
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, newConvOp.getResult());
|
|
||||||
return matchSuccess();
|
|
||||||
};
|
|
||||||
};
|
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
/// on the ONNXMaxPoolSingleOutOp.
|
/// on the ONNXMaxPoolSingleOutOp.
|
||||||
|
@ -177,5 +80,5 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
|
||||||
/// on the ONNXConvOp.
|
/// on the ONNXConvOp.
|
||||||
void ONNXConvOp::getCanonicalizationPatterns(
|
void ONNXConvOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
results.insert<SplitConvOpPattern>(context);
|
results.insert<ConvOpPaddingPattern>(context);
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,6 +62,10 @@ class insertZerosForNonPaddedDims<int extensionLength>:
|
||||||
def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">,
|
def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">,
|
||||||
"has non-zero elements">;
|
"has non-zero elements">;
|
||||||
|
|
||||||
|
// Check that a StrAttr does not contain a specific value.
|
||||||
|
class IsNotStringAttrOfValue<string val>:
|
||||||
|
Constraint<CPred<"$0.cast<StringAttr>().getValue() != \"" # val # "\"">>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Rewrite:
|
// Rewrite:
|
||||||
// %0 = onnx.MaxPoolSingleOutOp(%D : tensor<DShape>)
|
// %0 = onnx.MaxPoolSingleOutOp(%D : tensor<DShape>)
|
||||||
|
@ -90,7 +94,38 @@ def MaxPoolSingleOutOpPaddingPattern: Pat<
|
||||||
$auto_pad, $ceil_mode, $dilation, $kernel_shape,
|
$auto_pad, $ceil_mode, $dilation, $kernel_shape,
|
||||||
(createArrayAttrOfZerosFrom $pads),
|
(createArrayAttrOfZerosFrom $pads),
|
||||||
$storage_order, $strides),
|
$storage_order, $strides),
|
||||||
[(HasNonZeroInArrayAttr:$pads)]
|
[(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
|
||||||
|
>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Rewrite:
|
||||||
|
// %0 = onnx.ConvOp(%D : tensor<DShape>, %K)
|
||||||
|
// {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
|
||||||
|
// tensor<OutShape>
|
||||||
|
//
|
||||||
|
// as:
|
||||||
|
// %0 = onnx.PadConstantValuePadOp(%D)
|
||||||
|
// {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
|
||||||
|
// tensor<DPaddedShape>
|
||||||
|
// %1 = onnx.Conv(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
|
||||||
|
// tensor<OutShape>
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ConvOpPaddingPattern: Pat<
|
||||||
|
(ONNXConvOp:$res
|
||||||
|
$x,
|
||||||
|
$w, $b, $auto_pad, $dilation, $group, $kernel_shape,
|
||||||
|
$pads,
|
||||||
|
$strides),
|
||||||
|
(ONNXConvOp
|
||||||
|
(ONNXPadConstantValuePadOp $x,
|
||||||
|
(insertZerosForNonPaddedDims<2> $pads),
|
||||||
|
(FloatAttrOfValue<0> $res),
|
||||||
|
(StringAttrOfValue<"constant">)),
|
||||||
|
$w, $b, $auto_pad, $dilation, $group, $kernel_shape,
|
||||||
|
(createArrayAttrOfZerosFrom $pads),
|
||||||
|
$strides),
|
||||||
|
[(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
|
||||||
>;
|
>;
|
||||||
|
|
||||||
#endif // ONNX_REWRITE
|
#endif // ONNX_REWRITE
|
||||||
|
|
Loading…
Reference in New Issue