Split convolution into explicit padding and unpadded convolution. (#82)

* Split convolution into explicit padding and unpadded convolution.
* Refactor code. Add test.
parent 17d84901b7
commit 3c505ae31d
@@ -111,6 +111,7 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",

def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  let hasCanonicalizer = 1;
  let summary = "ONNX Conv operation with no Bias operand.";
  let description = [{
  "The convolution operator consumes an input tensor and a filter, and"
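The only functional change in this hunk is the new `let hasCanonicalizer = 1;` line. Per MLIR's ODS conventions, this flag makes tablegen declare a `getCanonicalizationPatterns` hook on the generated `ONNXConvNoBiasOp` class; the definition that registers the new pattern appears further down in this commit. As a sketch (the exact surrounding class body is produced by mlir-tblgen and omitted here), the declaration looks like:

```cpp
// Sketch only: the declaration that `hasCanonicalizer = 1` causes ODS to emit
// inside the generated ONNXConvNoBiasOp class. The matching definition, added
// later in this commit, inserts SplitConvOpPattern into the pattern list.
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context);
```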
@@ -39,14 +39,14 @@ def GemmTransB : NativeCodeCall<"$_builder.getI64IntegerAttr(0)">;
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z)
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
                                 (ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
                                 [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>;

// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
                                     (replaceWithValue $arg)>;

def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $v1, $v2), $m2, $m3),
                             (ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
                             [(HasOneUse $res)]>;

#endif // ONNX_COMBINE
@@ -263,6 +263,103 @@ struct ReduceSumSquareOpPattern : public RewritePattern {
    return matchSuccess();
  };
};

//===----------------------------------------------------------------------===//
// Rewrite:
// %0 = onnx.ConvNoBiasOp(%D : tensor<DShape>, %K)
//     {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
//     tensor<OutShape>
//
// as:
// %0 = onnx.PadConstantValuePadOp(%D)
//     {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
//     tensor<DPaddedShape>
// %1 = onnx.ConvNoBias(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
//     tensor<OutShape>
//===----------------------------------------------------------------------===//
struct SplitConvOpPattern : public RewritePattern {
  SplitConvOpPattern(MLIRContext *context)
      : RewritePattern(ONNXConvNoBiasOp::getOperationName(),
                       {ONNXPadConstantValuePadOp::getOperationName(),
                        ONNXConvNoBiasOp::getOperationName()},
                       1, context) {}

  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    // If the convolution does not use padding then no rewrite is required.
    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
    auto padsAttribute = convOp.padsAttr();
    if (!padsAttribute)
      return matchFailure();

    // If auto_pad is VALID then no padding happens and no rewrite is required.
    auto autoPad = convOp.auto_pad();
    if (autoPad == "VALID")
      return matchFailure();

    auto data = op->getOperands()[0];
    auto inputShape = data.getType().cast<TensorType>().getShape();

    // Dimensionality of the input:
    //              inputRank
    //      |----------------------|
    // D : (N x C x D1 x D2 x ... DK)
    //              |______________|
    //                  inputDims
    //
    int64_t inputRank = inputShape.size();
    int64_t inputDims = inputRank - 2;

    // If all pads values are equal to zero then no rewrite is required.
    bool allZeros = true;
    for (auto padsValue : padsAttribute.getValue()) {
      if (padsValue.cast<IntegerAttr>().getInt() > 0) {
        allZeros = false;
        break;
      }
    }

    if (allZeros)
      return matchFailure();

    // Create padding vector for the explicit padding op attribute.
    SmallVector<int64_t, 4> pads(2 * inputRank, 0);
    SmallVector<int64_t, 4> outPaddedShape(inputRank, 0);
    outPaddedShape[0] = inputShape[0];
    outPaddedShape[1] = inputShape[1];
    for (int i = 0; i < inputDims; ++i) {
      int64_t beginPad =
          padsAttribute.getValue()[i].cast<IntegerAttr>().getInt();
      int64_t endPad =
          padsAttribute.getValue()[inputDims + i].cast<IntegerAttr>().getInt();
      pads[i + 2] = beginPad;
      pads[inputRank + i + 2] = endPad;
      outPaddedShape[i + 2] += beginPad + inputShape[i + 2] + endPad;
    }

    // Create padding operation.
    auto inputElemType = data.getType().cast<TensorType>().getElementType();
    ONNXPadConstantValuePadOp paddingOp =
        rewriter.create<ONNXPadConstantValuePadOp>(
            loc, RankedTensorType::get(outPaddedShape, inputElemType), data,
            rewriter.getI64ArrayAttr(pads), FloatAttr::get(inputElemType, 0),
            StringAttr::get("constant", loc->getContext()));

    SmallVector<int64_t, 4> newConvPads(2 * inputDims, 0);
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    ONNXConvNoBiasOp newConvOp = rewriter.create<ONNXConvNoBiasOp>(
        loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1],
        convOp.auto_padAttr(), convOp.dilationsAttr(),
        convOp.groupAttr(), convOp.kernel_shapeAttr(),
        rewriter.getI64ArrayAttr(newConvPads),
        convOp.stridesAttr());

    rewriter.replaceOp(op, newConvOp.getResult());
    return matchSuccess();
  };
};
} // end anonymous namespace

/// on the ONNXReduceL1Op.
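To make the index bookkeeping in the loop above concrete, here is a small standalone C++ sketch (plain containers and printing of my own, no MLIR dependencies) that reproduces the pad-layout computation for the shapes used in the test added at the end of this commit: a 1x9x32x64 input with ONNX pads = [2, 3, 4, 5].

```cpp
// Standalone sketch of the pad-layout computation performed by
// SplitConvOpPattern, using the shapes from the test case below.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> inputShape = {1, 9, 32, 64}; // N x C x D1 x D2
  std::vector<int64_t> convPads = {2, 3, 4, 5};     // [b0, b1, e0, e1]
  int64_t inputRank = inputShape.size();
  int64_t inputDims = inputRank - 2;

  // Explicit pad attribute covers every dim: [0, 0, b0, b1, 0, 0, e0, e1].
  std::vector<int64_t> pads(2 * inputRank, 0);
  std::vector<int64_t> outPaddedShape(inputShape.begin(), inputShape.end());
  for (int64_t i = 0; i < inputDims; ++i) {
    int64_t beginPad = convPads[i];
    int64_t endPad = convPads[inputDims + i];
    pads[i + 2] = beginPad;
    pads[inputRank + i + 2] = endPad;
    outPaddedShape[i + 2] += beginPad + endPad;
  }

  std::cout << "pads =";
  for (int64_t p : pads) std::cout << ' ' << p;
  std::cout << "\npadded shape =";
  for (int64_t d : outPaddedShape) std::cout << ' ' << d;
  std::cout << '\n';
  return 0;
}
```

It prints `pads = 0 0 2 3 0 0 4 5` and `padded shape = 1 9 38 72`, matching the pads attribute and the `tensor<1x9x38x72xf32>` type in the CHECK lines of the test below.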
@@ -293,3 +390,9 @@ void ONNXReduceSumSquareOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ReduceSumSquareOpPattern>(context);
}

/// on the ONNXConvNoBiasOp.
void ONNXConvNoBiasOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SplitConvOpPattern>(context);
}
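Because the pattern is exposed through `getCanonicalizationPatterns`, no dedicated pass is needed: any pipeline that schedules MLIR's standard canonicalizer will apply the split. The following pass-manager snippet is an illustrative sketch, not code from this repository:

```cpp
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Illustrative only: schedule the standard canonicalizer, which greedily
// applies every pattern returned by the ops' getCanonicalizationPatterns
// hooks, including SplitConvOpPattern for onnx.ConvNoBias.
void buildCanonicalizationPipeline(mlir::PassManager &pm) {
  pm.addPass(mlir::createCanonicalizerPass());
}
```

The FileCheck expectations in the canonicalization test below describe exactly the IR produced by that rewrite.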
@@ -92,3 +92,12 @@ func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
  %2 = "onnx.PadConstantValue"(%arg0, %0) {constant_value=0. : f32, mode = "constant"} : (tensor<?x?xf32>, tensor<?xi64>)-> tensor<*xf32>
  "std.return"(%2) : (tensor<*xf32>) -> ()
}

// CHECK-LABEL: @test_conv_split(%{{.*}}: tensor<1x9x32x64xf32>, %{{.*}}: tensor<5x9x6x7xf32>) -> tensor<*xf32> {
func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> {
  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()
  // CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 3, 0, 0, 4, 5]} : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32>
  // CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
  // CHECK-NEXT: return %1 : tensor<*xf32>
}