Dropout elimination & Conv Bugfix (#297)
* Dropout elimination. * Test VGG19. * Add shufflenet. * Fix grouped convolution bug. * Fix lit test failure.
This commit is contained in:
parent
03dae57189
commit
5e11429d77
|
@ -135,8 +135,8 @@ struct ONNXConvOpLowering : public ConversionPattern {
|
||||||
/*mIndex=*/rewriter.getAffineDimExpr(1));
|
/*mIndex=*/rewriter.getAffineDimExpr(1));
|
||||||
kernel = rewriter.create<AffineApplyOp>(loc, kernelMap,
|
kernel = rewriter.create<AffineApplyOp>(loc, kernelMap,
|
||||||
ArrayRef<Value>{/*gIndex=*/outerLoops.getInductionVar(gIndex),
|
ArrayRef<Value>{/*gIndex=*/outerLoops.getInductionVar(gIndex),
|
||||||
/*kernelsPerGroupValue=*/kernelsPerGroupValue,
|
/*mIndex=*/outerLoops.getInductionVar(mIndex),
|
||||||
/*mIndex=*/outerLoops.getInductionVar(mIndex)});
|
/*kernelsPerGroupValue=*/kernelsPerGroupValue});
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2.2 Define spatial loops
|
// 2.2 Define spatial loops
|
||||||
|
|
|
@ -1046,6 +1046,7 @@ def ONNXDivOp:ONNX_Op<"Div",
|
||||||
|
|
||||||
def ONNXDropoutOp:ONNX_Op<"Dropout",
|
def ONNXDropoutOp:ONNX_Op<"Dropout",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX Dropout operation";
|
let summary = "ONNX Dropout operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Dropout takes one input floating tensor and produces two tensor outputs,"
|
"Dropout takes one input floating tensor and produces two tensor outputs,"
|
||||||
|
|
|
@ -95,3 +95,9 @@ void ONNXTransposeOp::getCanonicalizationPatterns(
|
||||||
result.insert<FuseTransposePattern>(context);
|
result.insert<FuseTransposePattern>(context);
|
||||||
result.insert<RemoveIdentityTransposePattern>(context);
|
result.insert<RemoveIdentityTransposePattern>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// on the ONNXDropoutOp.
|
||||||
|
void ONNXDropoutOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &result, MLIRContext *context) {
|
||||||
|
result.insert<DropoutEliminationPattern>(context);
|
||||||
|
}
|
||||||
|
|
|
@ -24,6 +24,11 @@ include "src/Dialect/ONNX/ONNXOps.td"
|
||||||
/// dag benefitsAdded = (addBenefit 0)
|
/// dag benefitsAdded = (addBenefit 0)
|
||||||
/// >;
|
/// >;
|
||||||
|
|
||||||
|
// Usefult code generation invokation.
|
||||||
|
def GetNullAttr : NativeCodeCall<"Attribute()">;
|
||||||
|
|
||||||
|
def GetUnitAttr: NativeCodeCall<"$_builder.getUnitAttr()">;
|
||||||
|
|
||||||
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
||||||
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
|
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
|
||||||
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
|
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
|
||||||
|
@ -55,11 +60,15 @@ def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none
|
||||||
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
|
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
|
||||||
(replaceWithValue $arg)>;
|
(replaceWithValue $arg)>;
|
||||||
|
|
||||||
|
// y, mask = onnx.Dropout(x) -> y, mask = x, none
|
||||||
|
def DropoutEliminationPattern : Pattern<(ONNXDropoutOp $arg, $ratio),
|
||||||
|
[(replaceWithValue $arg),
|
||||||
|
(ONNXConstantOp (GetNullAttr), (GetUnitAttr))]>;
|
||||||
|
|
||||||
def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $v1, $v2), $m2, $m3),
|
def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $v1, $v2), $m2, $m3),
|
||||||
(ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
|
(ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
|
||||||
[(HasOneUse $res)]>;
|
[(HasOneUse $res)]>;
|
||||||
|
|
||||||
|
|
||||||
// ONNX_Op (onnx.Cast (%X, $type)) = ONNX_Op (%X)
|
// ONNX_Op (onnx.Cast (%X, $type)) = ONNX_Op (%X)
|
||||||
def CastEliminationPattern : Pat<
|
def CastEliminationPattern : Pat<
|
||||||
(ONNXCastOp $arg, $type),
|
(ONNXCastOp $arg, $type),
|
||||||
|
|
|
@ -415,8 +415,10 @@ test_to_enable = [
|
||||||
"test_split_variable_parts_2d_cpu",
|
"test_split_variable_parts_2d_cpu",
|
||||||
"test_split_variable_parts_default_axis_cpu",
|
"test_split_variable_parts_default_axis_cpu",
|
||||||
|
|
||||||
# ResNet
|
# Model
|
||||||
"test_resnet50_cpu",
|
"test_resnet50_cpu",
|
||||||
|
"test_vgg19_cpu",
|
||||||
|
"test_shufflenet_cpu",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1272,13 +1272,13 @@ func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : te
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_no_pad_w_group
|
// CHECK-LABEL: test_conv_no_bias_no_pad_w_group
|
||||||
// CHECK: [[RES:%.+]] = alloc() : memref<1x5x27x58xf32>
|
// CHECK: [[RES:%.+]] = alloc() : memref<1x5x27x58xf32>
|
||||||
// CHECK: [[CONST0:%.+]] = constant 1 : index
|
// CHECK: %[[CONST0:.+]] = constant 1 : index
|
||||||
// CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
|
// CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
|
||||||
// CHECK: [[CONST2:%.+]] = constant 3 : index
|
// CHECK: [[CONST2:%.+]] = constant 3 : index
|
||||||
// CHECK: [[OUTER_LOOPS:%.+]]:3 = krnl.define_loops 3
|
// CHECK: [[OUTER_LOOPS:%.+]]:3 = krnl.define_loops 3
|
||||||
|
|
||||||
// CHECK: krnl.iterate([[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1, [[OUTER_LOOPS]]#2) with ([[OUTER_LOOPS]]#0 -> %arg2 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg3 = 0 to 3, [[OUTER_LOOPS]]#2 -> %arg4 = 0 to 1) {
|
// CHECK: krnl.iterate([[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1, [[OUTER_LOOPS]]#2) with ([[OUTER_LOOPS]]#0 -> %arg2 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg3 = 0 to 3, [[OUTER_LOOPS]]#2 -> %arg4 = 0 to 1) {
|
||||||
// CHECK: %[[ADD1:.+]] = affine.apply #{{.*}}(%arg3, [[CONST0]])[%arg4]
|
// CHECK: %[[ADD1:.+]] = affine.apply #{{.*}}(%arg3, %arg4)[%[[CONST0]]]
|
||||||
// CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2
|
// CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
|
||||||
// CHECK: krnl.iterate([[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg5 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg6 = 0 to 58) {
|
// CHECK: krnl.iterate([[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg5 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg6 = 0 to 58) {
|
||||||
|
|
|
@ -321,7 +321,7 @@ OpsWithShapeInference=[
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose']
|
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout']
|
||||||
|
|
||||||
# Operations who have operands that, if produced by constant operations, should
|
# Operations who have operands that, if produced by constant operations, should
|
||||||
# be promoted to become an attribute (via attribute promotion).
|
# be promoted to become an attribute (via attribute promotion).
|
||||||
|
|
Loading…
Reference in New Issue