Support Pads for MaxPoolSingleOut (#14)
* Support Pads for MaxPoolSingleOut * Regenerate onnx.md to include the new op * Edit comments * Undo redundant parts that were unintentionally changed * Move declarative rewriting rules into canonicalize to avoid creating a new op * Reformat the rewriting rule pattern of MaxPoolSingleOut * Put ONNXPadConstantValuePadOp's build method into a .cpp file instead of a tablegen file * Use the same helper function as the one in inferShape for the ONNXPadConstantValuePadOp's build method * Change function names and fix padding for the spatial dimensions * Call shape-inference again after canonicalization to infer shape for newly added ops during canonicalization. * Fix typos
This commit is contained in:
parent
718ec85479
commit
1882059ac9
|
@ -111,6 +111,7 @@ def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
||||||
|
|
||||||
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX MaxPool operation with a single output.";
|
let summary = "ONNX MaxPool operation with a single output.";
|
||||||
let description = [{
|
let description = [{
|
||||||
"ONNX MaxPool operation with a single output."
|
"ONNX MaxPool operation with a single output."
|
||||||
|
@ -195,6 +196,10 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
|
||||||
DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
|
DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
|
||||||
DefaultValuedAttr<StrAttr, "constant">:$mode);
|
DefaultValuedAttr<StrAttr, "constant">:$mode);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
||||||
|
// A build method with the result type deduction.
|
||||||
|
let builders = [OpBuilder<"Builder *builder, OperationState &state, "
|
||||||
|
"Value data, ArrayAttr pads, "
|
||||||
|
"FloatAttr constant_value, StringAttr mode">];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1095,6 +1095,16 @@ void ONNXPadConstantValuePadOp::inferShapes(){
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
|
||||||
|
Value data, ArrayAttr pads, FloatAttr constant_value, StringAttr mode) {
|
||||||
|
Type outputType = padShapeInferenceHelper(data, pads);
|
||||||
|
if (!outputType) {
|
||||||
|
auto elementType = data.getType().cast<TensorType>().getElementType();
|
||||||
|
outputType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, outputType, data, pads, constant_value, mode);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Unsqueeze
|
// Unsqueeze
|
||||||
|
|
|
@ -125,6 +125,7 @@ int main(int argc, char *argv[]) {
|
||||||
pm.addPass(mlir::createDecomposeONNXToONNXPass());
|
pm.addPass(mlir::createDecomposeONNXToONNXPass());
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
|
|
||||||
if (emissionTarget >= EmitMLIR) {
|
if (emissionTarget >= EmitMLIR) {
|
||||||
pm.addPass(mlir::createLowerToKrnlPass());
|
pm.addPass(mlir::createLowerToKrnlPass());
|
||||||
|
|
|
@ -17,6 +17,56 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Check whether an ArrayAttr contains non-zero values or not.
|
||||||
|
bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
|
||||||
|
bool allZeros = true;
|
||||||
|
if (attrs) {
|
||||||
|
for (auto attr: attrs.getValue()) {
|
||||||
|
if (attr.cast<IntegerAttr>().getInt() > 0) {
|
||||||
|
allZeros = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return !allZeros;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an ArrayAttr of IntergerAttr(s) of zero values.
|
||||||
|
// This function is used for padding attribute in MaxPoolSingleOut.
|
||||||
|
ArrayAttr createArrayAttrOfZeros(
|
||||||
|
PatternRewriter &rewriter, ArrayAttr origAttrs) {
|
||||||
|
int nElements = origAttrs.getValue().size();
|
||||||
|
SmallVector<int64_t, 4> vals(nElements, 0);
|
||||||
|
return rewriter.getI64ArrayAttr(vals);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pad a ArrayAttr with zeros.
|
||||||
|
//
|
||||||
|
// pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
|
||||||
|
//
|
||||||
|
// becomes:
|
||||||
|
//
|
||||||
|
// pads = [0,... 0, B1, B2, ... Bk, 0,... 0, E1, E2, ..., Ek]
|
||||||
|
// |_____| |_____|
|
||||||
|
// nZeros nZeros
|
||||||
|
//
|
||||||
|
// This function is used for padding attribute in MaxPoolSingleOut.
|
||||||
|
ArrayAttr insertZerosForNonPaddedDims(
|
||||||
|
PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) {
|
||||||
|
int nDims = (int) origAttrs.getValue().size() / 2;
|
||||||
|
int nElements = (nDims + extensionLength) * 2;
|
||||||
|
SmallVector<int64_t, 4> pads(nElements, 0);
|
||||||
|
for (int i = 0; i < nDims; ++i) {
|
||||||
|
int64_t beginPad = origAttrs.getValue()[i].cast<IntegerAttr>().getInt();
|
||||||
|
int64_t endPad =
|
||||||
|
origAttrs.getValue()[nDims + i].cast<IntegerAttr>().getInt();
|
||||||
|
pads[i + extensionLength] = beginPad;
|
||||||
|
pads[nDims + extensionLength + i + extensionLength] = endPad;
|
||||||
|
}
|
||||||
|
return rewriter.getI64ArrayAttr(pads);
|
||||||
|
}
|
||||||
|
|
||||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||||
#include "src/onnx_rewrite.inc"
|
#include "src/onnx_rewrite.inc"
|
||||||
|
|
||||||
|
@ -118,6 +168,11 @@ struct SplitConvOpPattern : public RewritePattern {
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
/// on the ONNXMaxPoolSingleOutOp.
|
||||||
|
void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<MaxPoolSingleOutOpPaddingPattern>(context);
|
||||||
|
}
|
||||||
/// on the ONNXReduceSumSquareOp.
|
/// on the ONNXReduceSumSquareOp.
|
||||||
void ONNXConvNoBiasOp::getCanonicalizationPatterns(
|
void ONNXConvNoBiasOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
|
|
@ -24,4 +24,68 @@ include "dialect/onnx/onnx.td"
|
||||||
/// dag benefitsAdded = (addBenefit 0)
|
/// dag benefitsAdded = (addBenefit 0)
|
||||||
/// >;
|
/// >;
|
||||||
|
|
||||||
|
// Create a StringAttr from a string.
|
||||||
|
class StringAttrOfValue<string val>:
|
||||||
|
NativeCodeCall<"$_builder.getStringAttr(\"" # val # "\")">;
|
||||||
|
|
||||||
|
// Create a FloatAttr from an interger value.
|
||||||
|
// It seems Table-gen does not support `float` type, so we can not pass a float value.
|
||||||
|
class FloatAttrOfValue<int val>:
|
||||||
|
NativeCodeCall<"FloatAttr::get($0.getType().cast<TensorType>().getElementType(), " # val # ")">;
|
||||||
|
|
||||||
|
// Create an ArrayAttr of IntergerAttr(s) of zero values.
|
||||||
|
// This function is used for padding attribute in MaxPoolSingleOut.
|
||||||
|
def createArrayAttrOfZerosFrom:
|
||||||
|
NativeCodeCall<"createArrayAttrOfZeros($_builder, $0)">;
|
||||||
|
|
||||||
|
// Pad a ArrayAttr with zeros.
|
||||||
|
//
|
||||||
|
// pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
|
||||||
|
//
|
||||||
|
// becomes:
|
||||||
|
//
|
||||||
|
// pads = [0,... 0, B1, B2, ... Bk, 0,... 0, E1, E2, ..., Ek]
|
||||||
|
// |_____| |_____|
|
||||||
|
// nZeros nZeros
|
||||||
|
//
|
||||||
|
// This function is used for padding attribute in MaxPoolSingleOut.
|
||||||
|
class insertZerosForNonPaddedDims<int extensionLength>:
|
||||||
|
NativeCodeCall<"insertZerosForNonPaddedDims($_builder, $0,"
|
||||||
|
# extensionLength # ")">;
|
||||||
|
|
||||||
|
// Check whether an ArrayAttr contains non-zero values or not.
|
||||||
|
def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">,
|
||||||
|
"has non-zero elements">;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Rewrite:
|
||||||
|
// %0 = onnx.MaxPoolSingleOutOp(%D : tensor<DShape>)
|
||||||
|
// {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
|
||||||
|
// tensor<OutShape>
|
||||||
|
//
|
||||||
|
// as:
|
||||||
|
// %0 = onnx.PadConstantValuePadOp(%D)
|
||||||
|
// {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
|
||||||
|
// tensor<DPaddedShape>
|
||||||
|
// %1 = onnx.MaxPoolSingleOut(%0 : tensor<DPaddedShape>) {pads = [0, ..., 0]} ->
|
||||||
|
// tensor<OutShape>
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def MaxPoolSingleOutOpPaddingPattern: Pat<
|
||||||
|
(ONNXMaxPoolSingleOutOp:$res
|
||||||
|
$x,
|
||||||
|
$auto_pad, $ceil_mode, $dilation, $kernel_shape,
|
||||||
|
$pads,
|
||||||
|
$storage_order, $strides),
|
||||||
|
(ONNXMaxPoolSingleOutOp
|
||||||
|
(ONNXPadConstantValuePadOp $x,
|
||||||
|
(insertZerosForNonPaddedDims<2> $pads),
|
||||||
|
(FloatAttrOfValue<0> $res),
|
||||||
|
(StringAttrOfValue<"constant">)),
|
||||||
|
$auto_pad, $ceil_mode, $dilation, $kernel_shape,
|
||||||
|
(createArrayAttrOfZerosFrom $pads),
|
||||||
|
$storage_order, $strides),
|
||||||
|
[(HasNonZeroInArrayAttr:$pads)]
|
||||||
|
>;
|
||||||
|
|
||||||
#endif // ONNX_REWRITE
|
#endif // ONNX_REWRITE
|
||||||
|
|
|
@ -77,3 +77,23 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
|
||||||
// return [[GEMM]] : tensor<*xf32>
|
// return [[GEMM]] : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//CHECK-LABEL: @test_maxpoolsingleout_split(%{{.*}}: tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32> {
|
||||||
|
func @test_maxpoolsingleout_split(%arg0: tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32> {
|
||||||
|
%0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0, kernel_shape = [5,3], pads = [1, 2, 3, 4] } : (tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32>
|
||||||
|
"std.return"(%0) : (tensor<5x8x32x39xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 1, 2, 0, 0, 3, 4]} : (tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32>
|
||||||
|
// CHECK-NEXT: %1 = "onnx.MaxPoolSingleOut"(%0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [0, 0, 0, 0], storage_order = 0 : i64} : (tensor<5x8x32x39xf32>) -> tensor<5x8x32x39xf32>
|
||||||
|
// CHECK-NEXT: return %1 : tensor<5x8x32x39xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
//CHECK-LABEL: @test_maxpoolsingleout_split_unknown_dims(%{{.*}}: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
func @test_maxpoolsingleout_split_unknown_dims(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0, kernel_shape = [5,3], pads = [1, 2, 3, 4] } : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 1, 2, 0, 0, 3, 4]} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: %1 = "onnx.MaxPoolSingleOut"(%0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [0, 0, 0, 0], storage_order = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: return %1 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue