Support Pads for MaxPoolSingleOut (#14)
* Support Pads for MaxPoolSingleOut * Regenerate onnx.md to include the new op * Edit comments * Undo redundant parts that were unintentionally changed * Move declarative rewriting rules into canonicalize to avoid creating a new op * Reformat the rewriting rule pattern of MaxPoolSingleOut * Put ONNXPadConstantValuePadOp's build method into a .cpp file instead of a tablegen file * Use the same helper function as the one in inferShape for the ONNXPadConstantValuePadOp's build method * Change function names and fix padding for the spatial dimensions * Call shape-inference again after canonicalization to infer shape for newly added ops during canonicalization. * Fix typos
This commit is contained in:
parent
718ec85479
commit
1882059ac9
|
@ -111,6 +111,7 @@ def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
|||
|
||||
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let hasCanonicalizer = 1;
|
||||
let summary = "ONNX MaxPool operation with a single output.";
|
||||
let description = [{
|
||||
"ONNX MaxPool operation with a single output."
|
||||
|
@ -195,6 +196,10 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
|
|||
DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
|
||||
DefaultValuedAttr<StrAttr, "constant">:$mode);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
||||
// A build method with the result type deduction.
|
||||
let builders = [OpBuilder<"Builder *builder, OperationState &state, "
|
||||
"Value data, ArrayAttr pads, "
|
||||
"FloatAttr constant_value, StringAttr mode">];
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1095,6 +1095,16 @@ void ONNXPadConstantValuePadOp::inferShapes(){
|
|||
return;
|
||||
}
|
||||
|
||||
void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
|
||||
Value data, ArrayAttr pads, FloatAttr constant_value, StringAttr mode) {
|
||||
Type outputType = padShapeInferenceHelper(data, pads);
|
||||
if (!outputType) {
|
||||
auto elementType = data.getType().cast<TensorType>().getElementType();
|
||||
outputType = UnrankedTensorType::get(elementType);
|
||||
}
|
||||
build(builder, state, outputType, data, pads, constant_value, mode);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Unsqueeze
|
||||
|
|
|
@ -125,6 +125,7 @@ int main(int argc, char *argv[]) {
|
|||
pm.addPass(mlir::createDecomposeONNXToONNXPass());
|
||||
pm.addPass(mlir::createShapeInferencePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createShapeInferencePass());
|
||||
|
||||
if (emissionTarget >= EmitMLIR) {
|
||||
pm.addPass(mlir::createLowerToKrnlPass());
|
||||
|
|
|
@ -17,6 +17,56 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
// Check whether an ArrayAttr contains non-zero values or not.
|
||||
bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
|
||||
bool allZeros = true;
|
||||
if (attrs) {
|
||||
for (auto attr: attrs.getValue()) {
|
||||
if (attr.cast<IntegerAttr>().getInt() > 0) {
|
||||
allZeros = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return !allZeros;
|
||||
}
|
||||
|
||||
// Create an ArrayAttr of IntergerAttr(s) of zero values.
|
||||
// This function is used for padding attribute in MaxPoolSingleOut.
|
||||
ArrayAttr createArrayAttrOfZeros(
|
||||
PatternRewriter &rewriter, ArrayAttr origAttrs) {
|
||||
int nElements = origAttrs.getValue().size();
|
||||
SmallVector<int64_t, 4> vals(nElements, 0);
|
||||
return rewriter.getI64ArrayAttr(vals);
|
||||
}
|
||||
|
||||
// Pad a ArrayAttr with zeros.
|
||||
//
|
||||
// pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// pads = [0,... 0, B1, B2, ... Bk, 0,... 0, E1, E2, ..., Ek]
|
||||
// |_____| |_____|
|
||||
// nZeros nZeros
|
||||
//
|
||||
// This function is used for padding attribute in MaxPoolSingleOut.
|
||||
ArrayAttr insertZerosForNonPaddedDims(
|
||||
PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) {
|
||||
int nDims = (int) origAttrs.getValue().size() / 2;
|
||||
int nElements = (nDims + extensionLength) * 2;
|
||||
SmallVector<int64_t, 4> pads(nElements, 0);
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
int64_t beginPad = origAttrs.getValue()[i].cast<IntegerAttr>().getInt();
|
||||
int64_t endPad =
|
||||
origAttrs.getValue()[nDims + i].cast<IntegerAttr>().getInt();
|
||||
pads[i + extensionLength] = beginPad;
|
||||
pads[nDims + extensionLength + i + extensionLength] = endPad;
|
||||
}
|
||||
return rewriter.getI64ArrayAttr(pads);
|
||||
}
|
||||
|
||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||
#include "src/onnx_rewrite.inc"
|
||||
|
||||
|
@ -118,6 +168,11 @@ struct SplitConvOpPattern : public RewritePattern {
|
|||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// on the ONNXMaxPoolSingleOutOp.
|
||||
void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<MaxPoolSingleOutOpPaddingPattern>(context);
|
||||
}
|
||||
/// on the ONNXReduceSumSquareOp.
|
||||
void ONNXConvNoBiasOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
|
|
|
@ -24,4 +24,68 @@ include "dialect/onnx/onnx.td"
|
|||
/// dag benefitsAdded = (addBenefit 0)
|
||||
/// >;
|
||||
|
||||
// Create a StringAttr from a string.
|
||||
class StringAttrOfValue<string val>:
|
||||
NativeCodeCall<"$_builder.getStringAttr(\"" # val # "\")">;
|
||||
|
||||
// Create a FloatAttr from an interger value.
|
||||
// It seems Table-gen does not support `float` type, so we can not pass a float value.
|
||||
class FloatAttrOfValue<int val>:
|
||||
NativeCodeCall<"FloatAttr::get($0.getType().cast<TensorType>().getElementType(), " # val # ")">;
|
||||
|
||||
// Create an ArrayAttr of IntergerAttr(s) of zero values.
|
||||
// This function is used for padding attribute in MaxPoolSingleOut.
|
||||
def createArrayAttrOfZerosFrom:
|
||||
NativeCodeCall<"createArrayAttrOfZeros($_builder, $0)">;
|
||||
|
||||
// Pad a ArrayAttr with zeros.
|
||||
//
|
||||
// pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// pads = [0,... 0, B1, B2, ... Bk, 0,... 0, E1, E2, ..., Ek]
|
||||
// |_____| |_____|
|
||||
// nZeros nZeros
|
||||
//
|
||||
// This function is used for padding attribute in MaxPoolSingleOut.
|
||||
class insertZerosForNonPaddedDims<int extensionLength>:
|
||||
NativeCodeCall<"insertZerosForNonPaddedDims($_builder, $0,"
|
||||
# extensionLength # ")">;
|
||||
|
||||
// Check whether an ArrayAttr contains non-zero values or not.
|
||||
def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">,
|
||||
"has non-zero elements">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Rewrite:
|
||||
// %0 = onnx.MaxPoolSingleOutOp(%D : tensor<DShape>)
|
||||
// {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
|
||||
// tensor<OutShape>
|
||||
//
|
||||
// as:
|
||||
// %0 = onnx.PadConstantValuePadOp(%D)
|
||||
// {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
|
||||
// tensor<DPaddedShape>
|
||||
// %1 = onnx.MaxPoolSingleOut(%0 : tensor<DPaddedShape>) {pads = [0, ..., 0]} ->
|
||||
// tensor<OutShape>
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MaxPoolSingleOutOpPaddingPattern: Pat<
|
||||
(ONNXMaxPoolSingleOutOp:$res
|
||||
$x,
|
||||
$auto_pad, $ceil_mode, $dilation, $kernel_shape,
|
||||
$pads,
|
||||
$storage_order, $strides),
|
||||
(ONNXMaxPoolSingleOutOp
|
||||
(ONNXPadConstantValuePadOp $x,
|
||||
(insertZerosForNonPaddedDims<2> $pads),
|
||||
(FloatAttrOfValue<0> $res),
|
||||
(StringAttrOfValue<"constant">)),
|
||||
$auto_pad, $ceil_mode, $dilation, $kernel_shape,
|
||||
(createArrayAttrOfZerosFrom $pads),
|
||||
$storage_order, $strides),
|
||||
[(HasNonZeroInArrayAttr:$pads)]
|
||||
>;
|
||||
|
||||
#endif // ONNX_REWRITE
|
||||
|
|
|
@ -77,3 +77,23 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
|
|||
// return [[GEMM]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
//CHECK-LABEL: @test_maxpoolsingleout_split(%{{.*}}: tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32> {
|
||||
func @test_maxpoolsingleout_split(%arg0: tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32> {
|
||||
%0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0, kernel_shape = [5,3], pads = [1, 2, 3, 4] } : (tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32>
|
||||
"std.return"(%0) : (tensor<5x8x32x39xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 1, 2, 0, 0, 3, 4]} : (tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32>
|
||||
// CHECK-NEXT: %1 = "onnx.MaxPoolSingleOut"(%0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [0, 0, 0, 0], storage_order = 0 : i64} : (tensor<5x8x32x39xf32>) -> tensor<5x8x32x39xf32>
|
||||
// CHECK-NEXT: return %1 : tensor<5x8x32x39xf32>
|
||||
}
|
||||
|
||||
//CHECK-LABEL: @test_maxpoolsingleout_split_unknown_dims(%{{.*}}: tensor<*xf32>) -> tensor<*xf32> {
|
||||
func @test_maxpoolsingleout_split_unknown_dims(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0, kernel_shape = [5,3], pads = [1, 2, 3, 4] } : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 1, 2, 0, 0, 3, 4]} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %1 = "onnx.MaxPoolSingleOut"(%0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [0, 0, 0, 0], storage_order = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return %1 : tensor<*xf32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue