Dropout elimination & Conv Bugfix (#297)
* Dropout elimination.
* Test VGG19.
* Add shufflenet.
* Fix grouped convolution bug.
* Fix lit test failure.
This commit is contained in:
parent 03dae57189
commit 5e11429d77
@@ -135,8 +135,8 @@ struct ONNXConvOpLowering : public ConversionPattern {
           /*mIndex=*/rewriter.getAffineDimExpr(1));
       kernel = rewriter.create<AffineApplyOp>(loc, kernelMap,
           ArrayRef<Value>{/*gIndex=*/outerLoops.getInductionVar(gIndex),
-                          /*kernelsPerGroupValue=*/kernelsPerGroupValue,
-                          /*mIndex=*/outerLoops.getInductionVar(mIndex)});
+                          /*mIndex=*/outerLoops.getInductionVar(mIndex),
+                          /*kernelsPerGroupValue=*/kernelsPerGroupValue});
     }

     // 2.2 Define spatial loops
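The fix swaps the last two AffineApplyOp operands. A minimal Python sketch of why that matters (illustrative only, not the onnx-mlir lowering): judging from the dim/symbol comments above and the lit test further down, kernelMap computes kernel = gIndex * kernelsPerGroup + mIndex, with gIndex and mIndex as affine dims and kernelsPerGroup as a symbol, and affine.apply consumes all dimension operands before any symbol operands.

```python
# Hypothetical stand-in for kernelMap: (d0, d1)[s0] -> d0 * s0 + d1,
# i.e. kernel = gIndex * kernelsPerGroup + mIndex.
def apply_kernel_map(dims, symbols):
    (d0, d1), (s0,) = dims, symbols
    return d0 * s0 + d1

g, m, kernels_per_group = 2, 1, 3

# Fixed operand order: dims (gIndex, mIndex), then the symbol.
assert apply_kernel_map((g, m), (kernels_per_group,)) == 7  # 2*3 + 1

# The pre-fix order bound kernelsPerGroupValue as a dim and mIndex as the
# symbol, silently computing the wrong kernel index for grouped convolution.
assert apply_kernel_map((g, kernels_per_group), (m,)) == 5  # 2*1 + 3
```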
@@ -1046,6 +1046,7 @@ def ONNXDivOp:ONNX_Op<"Div",

 def ONNXDropoutOp:ONNX_Op<"Dropout",
   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let hasCanonicalizer = 1;
   let summary = "ONNX Dropout operation";
   let description = [{
   "Dropout takes one input floating tensor and produces two tensor outputs,"
@@ -95,3 +95,9 @@ void ONNXTransposeOp::getCanonicalizationPatterns(
   result.insert<FuseTransposePattern>(context);
   result.insert<RemoveIdentityTransposePattern>(context);
 }
+
+/// on the ONNXDropoutOp.
+void ONNXDropoutOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &result, MLIRContext *context) {
+  result.insert<DropoutEliminationPattern>(context);
+}
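Because the TableGen hunk above declares `let hasCanonicalizer = 1;` for Dropout, MLIR's canonicalize pass calls this hook to collect patterns for onnx.Dropout; that is how DropoutEliminationPattern, defined in the rewrite-rules hunk below, actually gets applied.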
@@ -24,6 +24,11 @@ include "src/Dialect/ONNX/ONNXOps.td"
 /// dag benefitsAdded = (addBenefit 0)
 /// >;

+// Usefult code generation invokation.
+def GetNullAttr : NativeCodeCall<"Attribute()">;
+
+def GetUnitAttr: NativeCodeCall<"$_builder.getUnitAttr()">;
+
 def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
 class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
 def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
@@ -55,11 +60,15 @@ def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none
 def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
                                      (replaceWithValue $arg)>;

+// y, mask = onnx.Dropout(x) -> y, mask = x, none
+def DropoutEliminationPattern : Pattern<(ONNXDropoutOp $arg, $ratio),
+                                        [(replaceWithValue $arg),
+                                         (ONNXConstantOp (GetNullAttr), (GetUnitAttr))]>;
+
 def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $v1, $v2), $m2, $m3),
                              (ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
                              [(HasOneUse $res)]>;


 // ONNX_Op (onnx.Cast (%X, $type)) = ONNX_Op (%X)
 def CastEliminationPattern : Pat<
   (ONNXCastOp $arg, $type),
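The comment line states the rewrite: the Dropout result is replaced by its input, and the mask by a none value (an ONNXConstantOp built from the GetNullAttr and GetUnitAttr helpers added above). A minimal Python sketch of the inference-time semantics this encodes (illustrative, not onnx-mlir code):

```python
def dropout_inference(x):
    # Outside training, Dropout is the identity; the mask output carries no
    # information, so the pattern rewrites it to a "none" constant.
    return x, None

y, mask = dropout_inference([1.0, 2.0, 3.0])
assert y == [1.0, 2.0, 3.0] and mask is None
```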
@@ -415,8 +415,10 @@ test_to_enable = [
     "test_split_variable_parts_2d_cpu",
     "test_split_variable_parts_default_axis_cpu",

-    # ResNet
+    # Model
     "test_resnet50_cpu",
+    "test_vgg19_cpu",
+    "test_shufflenet_cpu",
 ]

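The two new entries enable the VGG19 and ShuffleNet model tests from the ONNX backend suite, matching the "Test VGG19" and "Add shufflenet" items in the commit message; the section comment is generalized from `# ResNet` to `# Model` to cover all three models.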
@@ -1272,13 +1272,13 @@ func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : te

 // CHECK-LABEL: test_conv_no_bias_no_pad_w_group
 // CHECK: [[RES:%.+]] = alloc() : memref<1x5x27x58xf32>
-// CHECK: [[CONST0:%.+]] = constant 1 : index
+// CHECK: %[[CONST0:.+]] = constant 1 : index
 // CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
 // CHECK: [[CONST2:%.+]] = constant 3 : index
 // CHECK: [[OUTER_LOOPS:%.+]]:3 = krnl.define_loops 3

 // CHECK: krnl.iterate([[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1, [[OUTER_LOOPS]]#2) with ([[OUTER_LOOPS]]#0 -> %arg2 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg3 = 0 to 3, [[OUTER_LOOPS]]#2 -> %arg4 = 0 to 1) {
-// CHECK: %[[ADD1:.+]] = affine.apply #{{.*}}(%arg3, [[CONST0]])[%arg4]
+// CHECK: %[[ADD1:.+]] = affine.apply #{{.*}}(%arg3, %arg4)[%[[CONST0]]]
 // CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2

 // CHECK: krnl.iterate([[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg5 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg6 = 0 to 58) {
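This is the lit-test counterpart of the lowering fix above: `affine.apply` now takes the two loop induction variables (`%arg3`, `%arg4`) as its dimension operands and the kernels-per-group constant as its symbol, rather than the reverse, and the FileCheck capture is rewritten in the `%[[CONST0:.+]]` form so the later reference `%[[CONST0]]` matches the SSA name inside the symbol brackets.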
@@ -32,7 +32,7 @@ parser.add_argument("--dry-run-op-build-table",
                     action="store_true",
                     default=False)
 parser.add_argument("--check-operation-version",
                     help="check whether the imported onnx package has new operation or "
                          " newer version of operation compared with version stored in version_dicts",
                     action="store_true",
                     default=False)
@@ -321,7 +321,7 @@ OpsWithShapeInference=[
 ]

 # Operations supporting canonicalization.
-OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose']
+OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout']

 # Operations who have operands that, if produced by constant operations, should
 # be promoted to become an attribute (via attribute promotion).
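For context, the generator script consumes this list when emitting TableGen op definitions; the `let hasCanonicalizer = 1;` line in the Dropout hunk above is machine-generated from it. A minimal sketch of that use (the helper name is hypothetical):

```python
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout']

def emit_canonicalizer_decl(op_name):
    # Hypothetical helper: ops in the list get a hasCanonicalizer declaration
    # in their generated ONNX_Op definition.
    if op_name in OpsWithCanonicalizer:
        return '  let hasCanonicalizer = 1;\n'
    return ''

assert 'hasCanonicalizer' in emit_canonicalizer_decl('Dropout')
assert emit_canonicalizer_decl('Div') == ''
```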
@@ -496,7 +496,7 @@ def get_tblgen_type_index(type_str):
     return tblgen_types.index(type_str)

 #the possible data structures are tensor, map and seq(tensor())
 def get_data_structure_element(allowed_type_str):
     structure_list = ['tensor', 'seq', 'map']
     for structure in structure_list:
         if allowed_type_str.startswith(structure) :
@@ -542,9 +542,9 @@ def get_allowed_elem_types(schema, input):
                 return allowed_structure, None
             if not t in allowed_type_list :
                 allowed_tyoe_list = allowed_type_list.append(t)

         return allowed_structure,allowed_type_list

     return None, None

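Unrelated to this commit, the context above carries an apparent pre-existing bug: `list.append` mutates in place and returns `None`, so `allowed_tyoe_list` (itself misspelled) is always `None` and is never read. A corrected sketch of the intended logic, with a wrapper name assumed for illustration:

```python
def add_allowed_type(allowed_type_list, t):
    # append mutates the list in place and returns None, so its result
    # must not be assigned back to a variable.
    if t not in allowed_type_list:
        allowed_type_list.append(t)
    return allowed_type_list

assert add_allowed_type(['F32'], 'F16') == ['F32', 'F16']
```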
@@ -812,7 +812,7 @@ def parse_type_str(allowedType):
         'float16' : 'F16',
         'float' : 'F32',
         'double' : 'F64',
         'unkown' : 'BF16',
         'complex64' : 'Complex<F32>',
         'complex128' : 'Complex<F64>',
         'string' : 'StringType'}
@@ -820,14 +820,14 @@ def parse_type_str(allowedType):
     for key, item in onnx_to_mlir_type_dict.items():
         allowedType = allowedType.replace(key, item)
     return allowedType

 def parse_a_type_constraint(constraint):
     allowedTypes = constraint.allowed_type_strs
     mlirTypes = []
     for allowedType in allowedTypes:
         mlirType = parse_type_str(allowedType)
         mlirTypes.append(mlirType)
     # Remove redundant and sort.
     # However onnx keeps a consitently meaningful order
     # There is no redundancy as long as each onnx type is mapped uniquely
     # mlirTypes = sorted(list(set(mlirTypes)))
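A runnable sketch of the substring-replacement approach in `parse_type_str`, with the dictionary trimmed to three entries; note that insertion order matters, since `'float16'` must be replaced before `'float'`:

```python
onnx_to_mlir_type_dict = {
    'float16': 'F16',  # must precede 'float': replace() works on substrings
    'float': 'F32',
    'double': 'F64',
}

def parse_type_str(allowed_type):
    # Rewrite an ONNX type string like 'tensor(float)' into its MLIR form.
    for key, item in onnx_to_mlir_type_dict.items():
        allowed_type = allowed_type.replace(key, item)
    return allowed_type

assert parse_type_str('tensor(float16)') == 'tensor(F16)'
assert parse_type_str('tensor(float)') == 'tensor(F32)'
```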
@@ -905,7 +905,7 @@ def gen_op_def(schema):
              (',\n' + inc_indent(indent)).join(outs_strs))

     # custom_builder_broadcast_ops_list

     # add custom builders
     # use element type of the first operand to construct an UnrankedTensorType for the output.
     if schema.name in custom_builder_ops_list:
@@ -973,7 +973,7 @@ def gen_op_def(schema):
                 '.getType().cast<TensorType>().getElementType();\n';
             s += indent + indent + 'elementType = UnrankedTensorType::get(elementType);\n'
             s += indent + '}\n';
         else:
             s += indent + 'auto elementType = operands[0].getType().' + \
                 'cast<TensorType>().getElementType();\n'
             s += indent + 'std::vector<mlir::Type> outputTypes;\n'