Support Pads for MaxPoolSingleOut (#14)

* Support Pads for MaxPoolSingleOut

* Regenerate onnx.md to include the new op

* Edit comments

* Undo redundant parts that were unintentionally changed

* Move declarative rewriting rules into canonicalize to avoid creating a new op

* Reformat the rewriting rule pattern of MaxPoolSingleOut

* Put ONNXPadConstantValuePadOp's build method into a .cpp file instead of a tablegen file

* Use the same helper function as the one in inferShape for the ONNXPadConstantValuePadOp's build method

* Change function names and fix padding for the spatial dimensions

* Run shape inference again after canonicalization, to infer shapes for ops newly created by the canonicalization patterns.

* Fix typos
This commit is contained in:
Tung D. Le 2020-03-10 09:15:58 +09:00 committed by GitHub
parent 718ec85479
commit 1882059ac9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 155 additions and 0 deletions

View File

@ -111,6 +111,7 @@ def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1;
let summary = "ONNX MaxPool operation with a single output."; let summary = "ONNX MaxPool operation with a single output.";
let description = [{ let description = [{
"ONNX MaxPool operation with a single output." "ONNX MaxPool operation with a single output."
@ -195,6 +196,10 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
DefaultValuedAttr<F32Attr, "0.0">:$constant_value, DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
DefaultValuedAttr<StrAttr, "constant">:$mode); DefaultValuedAttr<StrAttr, "constant">:$mode);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
// A build method with the result type deduction.
let builders = [OpBuilder<"Builder *builder, OperationState &state, "
"Value data, ArrayAttr pads, "
"FloatAttr constant_value, StringAttr mode">];
} }

View File

@ -1095,6 +1095,16 @@ void ONNXPadConstantValuePadOp::inferShapes(){
return; return;
} }
// Build an ONNXPadConstantValuePadOp, deducing the result type from the
// input value and the pads attribute. If the type cannot be inferred, fall
// back to an unranked tensor with the input's element type.
void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
    Value data, ArrayAttr pads, FloatAttr constant_value, StringAttr mode) {
  // Reuse the same helper as shape inference so both agree on the result type.
  Type resultType = padShapeInferenceHelper(data, pads);
  if (!resultType) {
    auto elemType = data.getType().cast<TensorType>().getElementType();
    resultType = UnrankedTensorType::get(elemType);
  }
  build(builder, state, resultType, data, pads, constant_value, mode);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Unsqueeze // Unsqueeze

View File

@ -125,6 +125,7 @@ int main(int argc, char *argv[]) {
pm.addPass(mlir::createDecomposeONNXToONNXPass()); pm.addPass(mlir::createDecomposeONNXToONNXPass());
pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createShapeInferencePass());
if (emissionTarget >= EmitMLIR) { if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass()); pm.addPass(mlir::createLowerToKrnlPass());

View File

@ -17,6 +17,56 @@
using namespace mlir; using namespace mlir;
namespace { namespace {
// Check whether an ArrayAttr contains at least one strictly positive value.
// A null attribute, or one whose elements are all zero (or negative),
// yields false.
bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
  if (!attrs)
    return false;
  for (auto attr : attrs.getValue())
    if (attr.cast<IntegerAttr>().getInt() > 0)
      return true;
  return false;
}
// Create an ArrayAttr of zero-valued IntegerAttrs with the same number of
// elements as `origAttrs`.
// Used to clear the padding attribute of MaxPoolSingleOut.
ArrayAttr createArrayAttrOfZeros(
    PatternRewriter &rewriter, ArrayAttr origAttrs) {
  SmallVector<int64_t, 4> zeros(origAttrs.getValue().size(), 0);
  return rewriter.getI64ArrayAttr(zeros);
}
// Pad an ArrayAttr with zeros.
//
// pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
//
// becomes:
//
// pads = [0,... 0, B1, B2, ... Bk, 0,... 0, E1, E2, ..., Ek]
//         |_____|                  |_____|
//          nZeros                   nZeros
//
// This function is used for padding attribute in MaxPoolSingleOut.
ArrayAttr insertZerosForNonPaddedDims(
    PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) {
  int nDims = (int)origAttrs.getValue().size() / 2;
  // Each half (begins / ends) grows by extensionLength leading zeros.
  int halfSize = nDims + extensionLength;
  SmallVector<int64_t, 4> pads(2 * halfSize, 0);
  for (int i = 0; i < nDims; ++i) {
    pads[extensionLength + i] =
        origAttrs.getValue()[i].cast<IntegerAttr>().getInt();
    pads[halfSize + extensionLength + i] =
        origAttrs.getValue()[nDims + i].cast<IntegerAttr>().getInt();
  }
  return rewriter.getI64ArrayAttr(pads);
}
/// Include the patterns defined in the Declarative Rewrite framework. /// Include the patterns defined in the Declarative Rewrite framework.
#include "src/onnx_rewrite.inc" #include "src/onnx_rewrite.inc"
@ -118,6 +168,11 @@ struct SplitConvOpPattern : public RewritePattern {
}; };
} // end anonymous namespace } // end anonymous namespace
/// Register the canonicalization pattern (padding split)
/// on the ONNXMaxPoolSingleOutOp.
void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<MaxPoolSingleOutOpPaddingPattern>(context);
}
/// on the ONNXReduceSumSquareOp. /// on the ONNXReduceSumSquareOp.
void ONNXConvNoBiasOp::getCanonicalizationPatterns( void ONNXConvNoBiasOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {

View File

@ -24,4 +24,68 @@ include "dialect/onnx/onnx.td"
/// dag benefitsAdded = (addBenefit 0) /// dag benefitsAdded = (addBenefit 0)
/// >; /// >;
// Create a StringAttr from a string literal supplied at tablegen time.
class StringAttrOfValue<string val>:
  NativeCodeCall<"$_builder.getStringAttr(\"" # val # "\")">;
// Create a FloatAttr whose type matches the element type of $0's tensor type.
// Table-gen does not appear to support a `float` literal, so the value is
// given as an integer.
class FloatAttrOfValue<int val>:
  NativeCodeCall<"FloatAttr::get($0.getType().cast<TensorType>().getElementType(), " # val # ")">;
// Create an ArrayAttr of IntegerAttr(s) of zero values, with the same
// element count as the input ArrayAttr.
// This function is used for padding attribute in MaxPoolSingleOut.
def createArrayAttrOfZerosFrom:
  NativeCodeCall<"createArrayAttrOfZeros($_builder, $0)">;
// Pad an ArrayAttr with zeros.
//
// pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
//
// becomes:
//
// pads = [0,... 0, B1, B2, ... Bk, 0,... 0, E1, E2, ..., Ek]
//         |_____|                  |_____|
//          nZeros                   nZeros
//
// This function is used for padding attribute in MaxPoolSingleOut.
class insertZerosForNonPaddedDims<int extensionLength>:
  NativeCodeCall<"insertZerosForNonPaddedDims($_builder, $0,"
                 # extensionLength # ")">;
// Check whether an ArrayAttr contains non-zero values or not.
// Delegates to the C++ helper hasNonZeroInArrayAttr.
def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">,
  "has non-zero elements">;
//===----------------------------------------------------------------------===//
// Rewrite:
//   %0 = onnx.MaxPoolSingleOutOp(%D : tensor<DShape>)
//          {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
//        tensor<OutShape>
//
// as:
//   %0 = onnx.PadConstantValuePadOp(%D)
//          {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
//        tensor<DPaddedShape>
//   %1 = onnx.MaxPoolSingleOut(%0 : tensor<DPaddedShape>) {pads = [0, ..., 0]} ->
//        tensor<OutShape>
//
// The two inserted zeros per half (extensionLength = 2) cover the two leading
// non-spatial dimensions, which MaxPool's pads attribute omits. The pattern
// only fires when pads contains a non-zero value (HasNonZeroInArrayAttr).
//===----------------------------------------------------------------------===//
def MaxPoolSingleOutOpPaddingPattern: Pat<
  (ONNXMaxPoolSingleOutOp:$res
     $x,
     $auto_pad, $ceil_mode, $dilation, $kernel_shape,
     $pads,
     $storage_order, $strides),
  (ONNXMaxPoolSingleOutOp
     (ONNXPadConstantValuePadOp $x,
        (insertZerosForNonPaddedDims<2> $pads),
        (FloatAttrOfValue<0> $res),
        (StringAttrOfValue<"constant">)),
     $auto_pad, $ceil_mode, $dilation, $kernel_shape,
     (createArrayAttrOfZerosFrom $pads),
     $storage_order, $strides),
  [(HasNonZeroInArrayAttr:$pads)]
>;
#endif // ONNX_REWRITE #endif // ONNX_REWRITE

View File

@ -77,3 +77,23 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
// return [[GEMM]] : tensor<*xf32> // return [[GEMM]] : tensor<*xf32>
} }
// Canonicalization splits the non-zero pads of MaxPoolSingleOut into an
// explicit PadConstantValuePad op (pads extended with zeros for the two
// leading dimensions), followed by a MaxPoolSingleOut with all-zero pads.
//CHECK-LABEL: @test_maxpoolsingleout_split(%{{.*}}: tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32> {
func @test_maxpoolsingleout_split(%arg0: tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32> {
%0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0, kernel_shape = [5,3], pads = [1, 2, 3, 4] } : (tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32>
"std.return"(%0) : (tensor<5x8x32x39xf32>) -> ()
// CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 1, 2, 0, 0, 3, 4]} : (tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32>
// CHECK-NEXT: %1 = "onnx.MaxPoolSingleOut"(%0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [0, 0, 0, 0], storage_order = 0 : i64} : (tensor<5x8x32x39xf32>) -> tensor<5x8x32x39xf32>
// CHECK-NEXT: return %1 : tensor<5x8x32x39xf32>
}
// Same split as above, but on an unranked input: the rewrite still fires and
// both new ops keep the unranked tensor type.
//CHECK-LABEL: @test_maxpoolsingleout_split_unknown_dims(%{{.*}}: tensor<*xf32>) -> tensor<*xf32> {
func @test_maxpoolsingleout_split_unknown_dims(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0, kernel_shape = [5,3], pads = [1, 2, 3, 4] } : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 1, 2, 0, 0, 3, 4]} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: %1 = "onnx.MaxPoolSingleOut"(%0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [0, 0, 0, 0], storage_order = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: return %1 : tensor<*xf32>
}