Dropout elimination & Conv Bugfix (#297)

* Dropout elimination.

* Test VGG19.

* Add shufflenet.

* Fix grouped convolution bug.

* Fix lit test failure.
This commit is contained in:
Tian Jin 2020-09-10 14:47:30 +08:00 committed by GitHub
parent 03dae57189
commit 5e11429d77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 34 additions and 16 deletions

View File

@ -135,8 +135,8 @@ struct ONNXConvOpLowering : public ConversionPattern {
/*mIndex=*/rewriter.getAffineDimExpr(1));
kernel = rewriter.create<AffineApplyOp>(loc, kernelMap,
ArrayRef<Value>{/*gIndex=*/outerLoops.getInductionVar(gIndex),
/*kernelsPerGroupValue=*/kernelsPerGroupValue,
/*mIndex=*/outerLoops.getInductionVar(mIndex)});
/*mIndex=*/outerLoops.getInductionVar(mIndex),
/*kernelsPerGroupValue=*/kernelsPerGroupValue});
}
// 2.2 Define spatial loops

View File

@ -1046,6 +1046,7 @@ def ONNXDivOp:ONNX_Op<"Div",
def ONNXDropoutOp:ONNX_Op<"Dropout",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1;
let summary = "ONNX Dropout operation";
let description = [{
"Dropout takes one input floating tensor and produces two tensor outputs,"

View File

@ -95,3 +95,9 @@ void ONNXTransposeOp::getCanonicalizationPatterns(
result.insert<FuseTransposePattern>(context);
result.insert<RemoveIdentityTransposePattern>(context);
}
/// Registers the canonicalization patterns (dropout elimination)
/// on the ONNXDropoutOp.
void ONNXDropoutOp::getCanonicalizationPatterns(
OwningRewritePatternList &result, MLIRContext *context) {
// DropoutEliminationPattern replaces the dropout's data result with its
// input and its mask result with a "none" constant (see the .td pattern).
result.insert<DropoutEliminationPattern>(context);
}

View File

@ -24,6 +24,11 @@ include "src/Dialect/ONNX/ONNXOps.td"
/// dag benefitsAdded = (addBenefit 0)
/// >;
// Useful code generation invocation.
// Materializes a null (empty) Attribute, used where "no value" is needed.
def GetNullAttr : NativeCodeCall<"Attribute()">;
// Materializes a UnitAttr via the pattern rewriter's builder.
def GetUnitAttr: NativeCodeCall<"$_builder.getUnitAttr()">;
// Matches only when the captured value has exactly one use.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// Matches only when the captured value is a ShapedType with the given rank.
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
// Matches only when the captured value has NoneType (an absent optional operand).
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
@ -55,11 +60,15 @@ def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none
// onnx.Identity(%X) -> %X : the identity op is a no-op and is removed by
// replacing its result with its operand.
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
(replaceWithValue $arg)>;
// y, mask = onnx.Dropout(x) -> y, mask = x, none
// At inference time dropout is the identity on its data input: the first
// result is replaced with the input, and the mask result is replaced with a
// constant carrying a null value attribute (a "none" placeholder).
def DropoutEliminationPattern : Pattern<(ONNXDropoutOp $arg, $ratio),
[(replaceWithValue $arg),
(ONNXConstantOp (GetNullAttr), (GetUnitAttr))]>;
// Folds the constant pad value produced by an ONNXConstantOp directly into
// an ONNXPadConstantValuePadOp, but only when the constant has no other
// users (so removing it is safe).
def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $v1, $v2), $m2, $m3),
(ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
[(HasOneUse $res)]>;
// ONNX_Op (onnx.Cast (%X, $type)) = ONNX_Op (%X)
def CastEliminationPattern : Pat<
(ONNXCastOp $arg, $type),

View File

@ -415,8 +415,10 @@ test_to_enable = [
"test_split_variable_parts_2d_cpu",
"test_split_variable_parts_default_axis_cpu",
# ResNet
# Model
"test_resnet50_cpu",
"test_vgg19_cpu",
"test_shufflenet_cpu",
]

View File

@ -1272,13 +1272,13 @@ func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : te
// CHECK-LABEL: test_conv_no_bias_no_pad_w_group
// CHECK: [[RES:%.+]] = alloc() : memref<1x5x27x58xf32>
// CHECK: [[CONST0:%.+]] = constant 1 : index
// CHECK: %[[CONST0:.+]] = constant 1 : index
// CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[CONST2:%.+]] = constant 3 : index
// CHECK: [[OUTER_LOOPS:%.+]]:3 = krnl.define_loops 3
// CHECK: krnl.iterate([[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1, [[OUTER_LOOPS]]#2) with ([[OUTER_LOOPS]]#0 -> %arg2 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg3 = 0 to 3, [[OUTER_LOOPS]]#2 -> %arg4 = 0 to 1) {
// CHECK: %[[ADD1:.+]] = affine.apply #{{.*}}(%arg3, [[CONST0]])[%arg4]
// CHECK: %[[ADD1:.+]] = affine.apply #{{.*}}(%arg3, %arg4)[%[[CONST0]]]
// CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: krnl.iterate([[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg5 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg6 = 0 to 58) {

View File

@ -32,7 +32,7 @@ parser.add_argument("--dry-run-op-build-table",
action="store_true",
default=False)
parser.add_argument("--check-operation-version",
help="check whether the imported onnx package has new operation or "
help="check whether the imported onnx package has new operation or "
" newer version of operation compared with version stored in version_dicts",
action="store_true",
default=False)
@ -321,7 +321,7 @@ OpsWithShapeInference=[
]
# Operations supporting canonicalization.
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose']
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout']
# Operations who have operands that, if produced by constant operations, should
# be promoted to become an attribute (via attribute promotion).
@ -496,7 +496,7 @@ def get_tblgen_type_index(type_str):
return tblgen_types.index(type_str)
#the possible data structures are tensor, map and seq(tensor())
def get_data_structure_element(allowed_type_str):
def get_data_structure_element(allowed_type_str):
structure_list = ['tensor', 'seq', 'map']
for structure in structure_list:
if allowed_type_str.startswith(structure) :
@ -542,9 +542,9 @@ def get_allowed_elem_types(schema, input):
return allowed_structure, None
if not t in allowed_type_list :
allowed_tyoe_list = allowed_type_list.append(t)
return allowed_structure,allowed_type_list
return None, None
@ -812,7 +812,7 @@ def parse_type_str(allowedType):
'float16' : 'F16',
'float' : 'F32',
'double' : 'F64',
'unkown' : 'BF16',
'unkown' : 'BF16',
'complex64' : 'Complex<F32>',
'complex128' : 'Complex<F64>',
'string' : 'StringType'}
@ -820,14 +820,14 @@ def parse_type_str(allowedType):
for key, item in onnx_to_mlir_type_dict.items():
allowedType = allowedType.replace(key, item)
return allowedType
def parse_a_type_constraint(constraint):
allowedTypes = constraint.allowed_type_strs
mlirTypes = []
for allowedType in allowedTypes:
mlirType = parse_type_str(allowedType)
mlirTypes.append(mlirType)
# Remove redundant and sort.
# Remove redundant and sort.
# However onnx keeps a consistently meaningful order
# There is no redundancy as long as each onnx type is mapped uniquely
# mlirTypes = sorted(list(set(mlirTypes)))
@ -905,7 +905,7 @@ def gen_op_def(schema):
(',\n' + inc_indent(indent)).join(outs_strs))
# custom_builder_broadcast_ops_list
# add custom builders
# use element type of the first operand to construct an UnrankedTensorType for the output.
if schema.name in custom_builder_ops_list:
@ -973,7 +973,7 @@ def gen_op_def(schema):
'.getType().cast<TensorType>().getElementType();\n';
s += indent + indent + 'elementType = UnrankedTensorType::get(elementType);\n'
s += indent + '}\n';
else:
else:
s += indent + 'auto elementType = operands[0].getType().' + \
'cast<TensorType>().getElementType();\n'
s += indent + 'std::vector<mlir::Type> outputTypes;\n'