clean up gen_doc.py (#59)

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
Authored by chentong319 on 2020-01-29 13:54:46 -05:00; committed by GitHub
parent ce8594fc60
commit 6959cf4586
3 changed files with 75 additions and 40 deletions

gen_doc.py

@@ -18,6 +18,44 @@ from onnx.backend.sample.ops import collect_sample_implementations
from typing import Any, Text, Sequence, Dict, List, Type, Set, Tuple
+#controls for ONNF code gen
+#specify attr default value
+special_attr_defaults = dict([
+#       ("AveragePool "+"kernel_shape", ('ints', '{}')),
+#       ("MaxPool "+"kernel_shape", ('ints', '{}')),
+#       ("Cast "+"to", ('int', '0')),
+#       ("Concat "+"axis", ('int', '0')),
+#       ("Conv "+"group", ('int', '1')),
+#       ("Unsqueeze "+"axes", ('ints', '{}')),
+#       ("RNN "+"activation_alpha", ('floats', '{}')),
+#       ("RNN "+"activation_beta", ('floats', '{}')),
+        ])
+#specify the function name in src/builder/frontend_dialect_transformer.cpp
+#Conv and MaxPool are special-cased to handle their optional arguments
+special_op_handler = dict([
+        ("Conv", "ImportNodeConv"),
+        ("MaxPool", "ImportNodeMaxPool"),
+        #("Transpose", "ImportNodeTranspose")
+        ])
+#add an Op to this list if ShapeInference is defined for it
+ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
+                    'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
+                    'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
+                    'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
+                    'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
+                    'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
+CanonicalList=['Add', 'Identity']
+manual_code_in_op_def = dict([
+      ('DummyExample', '  let extraClassDeclaration = [{ \n'+
+                       '    static StringRef getPermAttrName() { return "perm"; }\n'+
+                       '  }];\n')
+      ])
SNIPPETS = collect_snippets()
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
ONNX_ML = not bool(os.getenv('ONNX_ML') == '0')
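
Note on the block above: the commit hoists these tables to module scope so that gen_schema, gen_code, and gen_attr_ins consult one shared definition instead of each keeping a private copy (the copies are deleted in the hunks below). A minimal sketch of the two lookup conventions, not taken from the commit itself; lookup_op and its fallback name are illustrative only:

# Sketch (illustrative, not from the commit) of how the shared tables are keyed.
special_attr_defaults = dict([("Conv "+"group", ('int', '1'))])
special_op_handler = dict([("Conv", "ImportNodeConv")])

def lookup_op(schema_name, attr_name):
    # attribute overrides are keyed by "<op name> <attr name>"
    key = schema_name + ' ' + attr_name
    if key in special_attr_defaults:
        attr_type_str, attr_default_str = special_attr_defaults[key]
        print('default override:', attr_type_str, attr_default_str)
    # ops with hand-written importers dispatch by op name alone
    print('importer:', special_op_handler.get(schema_name, '<generated>'))

lookup_op('Conv', 'group')   # default override: int 1 / importer: ImportNodeConv
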
@@ -263,18 +301,6 @@ def collect_types(schema, input) :
return allowedTypeStr
def gen_schema(schema) :
-    ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
-        'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
-        'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
-        'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
-        'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
-        'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
-    CanonicalList=['Add', 'Identity']
-    manual_code = dict([
-      ('DummyExample', '  let extraClassDeclaration = [{ \n'+
-                       '    static StringRef getPermAttrName() { return "perm"; }\n'+
-                       '  }];\n')
-      ])
skip_attr_gen = []
line_indent = ' '
@@ -362,8 +388,8 @@ def gen_schema(schema) :
#s+= 'let hasCanonicalizer = 1;'
#add special code
-    if schema.name in manual_code :
-        s += manual_code[schema.name]
+    if schema.name in manual_code_in_op_def :
+        s += manual_code_in_op_def[schema.name]
s += '}\n\n'
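
Note: the rename above is mechanical, but the splice point is worth seeing in isolation. A condensed sketch of gen_schema's tail; only the manual-code lines are verbatim from this commit, and the def line follows the pattern visible in the generated .td file below:

# Condensed sketch of gen_schema's tail (illustrative).
manual_code_in_op_def = dict([
    ('DummyExample', '  let extraClassDeclaration = [{ \n' +
                     '    static StringRef getPermAttrName() { return "perm"; }\n' +
                     '  }];\n')
])

name = 'DummyExample'
s = 'def ONNX%sOp:ONNX_Op<"%s",\n  [NoSideEffect]> {\n' % (name, name)
if name in manual_code_in_op_def:
    s += manual_code_in_op_def[name]   # spliced just before the closing brace
s += '}\n\n'
print(s)
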
@@ -378,19 +404,14 @@ special cases:
"""
def gen_code(schema,fefile) :
-    special_handler = dict([
-        ("Conv", "ImportNodeConv"),
-        ("MaxPool", "ImportNodeMaxPool"),
-        #("Transpose", "ImportNodeTranspose")
-        ])
handle_variadic = False
line_indent = ' '
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
op_type_str='mlir::ONNX'+schema.name+'Op'
-    if schema.name in special_handler :
-        fefile.write(' '+special_handler[schema.name]+'(node, '
+    if schema.name in special_op_handler :
+        fefile.write(' '+special_op_handler[schema.name]+'(node, '
+str(len(schema.inputs))
+', ' +str(len(schema.outputs)))
elif len(schema.outputs) > 1 :
@@ -419,16 +440,6 @@ def gen_code(schema,fefile) :
fefile.write(', '+variadicIn+', '+variadicOut+');\n')
def gen_attr_ins(schema, isfirst) :
-    special_defaults = dict([
-        ("AveragePool "+"kernel_shape", ('ints', '{}')),
-        ("MaxPool "+"kernel_shape", ('ints', '{}')),
-        ("Cast "+"to", ('int', '0')),
-        ("Concat "+"axis", ('int', '0')),
-        ("Conv "+"group", ('int', '1')),
-        ("Unsqueeze "+"axes", ('ints', '{}')),
-        ("RNN "+"activation_alpha", ('floats', '{}')),
-        ("RNN "+"activation_beta", ('floats', '{}')),
-        ])
def get_attr_type_basic(attr_type) :
if attr_type == 'int' :
@@ -469,8 +480,8 @@ def gen_attr_ins(schema, isfirst) :
else :
isfirst = False
-        if schema.name+' '+attr.name in special_defaults:
-            (attr_type_str, attr_default_str) = special_defaults[schema.name+' '+attr.name]
+        if schema.name+' '+attr.name in special_attr_defaults:
+            (attr_type_str, attr_default_str) = special_attr_defaults[schema.name+' '+attr.name]
attr_line += get_attr_type_with_default(attr_type_str, attr_default_str)
attr_line += ':$'+attr.name
elif attr.required:
@@ -614,9 +625,21 @@ def main(args): # type: (Type[Args]) -> None
fout.write('\n')
tdfile= io.open(args.tdfile, 'w', newline='')
+    tdfile.write('//********************************************************\n'+
+                 '// Warning: Do not modify this file directly\n'+
+                 '// This file is automatically generated via script\n'+
+                 '// Details can be found in doc/readonnxdefs.md\n'+
+                 '//********************************************************\n\n'
+                 )
fefile=io.open('op_build_table.inc', 'w', newline='')
firstfunc = True
+    fefile.write('//********************************************************\n'+
+                 '// Warning: Do not modify this file directly\n'+
+                 '// This file is automatically generated via script\n'+
+                 '// Details can be found in doc/readonnxdefs.md\n'+
+                 '//********************************************************\n\n'
+                 )
fefile.write(' '+'if (OpName == "DUMMY") {\n')
for domain, supportmap in operator_schemas:
s = '## {}\n'.format(display_domain_short(domain))
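
Note on the "DUMMY" line above: seeding the chain with an always-false branch lets gen_code emit every real op uniformly as '}else if', with no special case for the first op. A sketch of the resulting shape; the op list is an illustrative subset:

# Sketch of the dispatch chain written into op_build_table.inc.
ops = ['Abs', 'Acos']                      # illustrative subset
chain = ['  if (OpName == "DUMMY") {\n']
for name in ops:
    chain.append('  }else if (OpName == "%s") {\n' % name)
    chain.append('    ImportNodeOneOut<mlir::ONNX%sOp>(node, 1, 1);\n' % name)
chain.append('  }\n')
print(''.join(chain), end='')
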

op_build_table.inc

@@ -1,3 +1,9 @@
+//********************************************************
+// Warning: Do not modify this file directly
+// This file is automatically generated via script
+// Details can be found in doc/readonnxdefs.md
+//********************************************************
if (OpName == "DUMMY") {
}else if (OpName == "Abs") {
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1);
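
The two integers in each generated call are simply the schema's declared input and output counts (len(schema.inputs) and len(schema.outputs) in gen_code above). A stand-alone sketch, with a stub object in place of the real ONNX schema:

# Stand-alone sketch; FakeSchema stands in for the real ONNX schema object.
class FakeSchema:
    name = 'Abs'
    inputs = ['X']    # ONNX Abs takes one input
    outputs = ['Y']   # and produces one output

schema = FakeSchema()
print('    ImportNodeOneOut<mlir::ONNX%sOp>(node, %d, %d);'
      % (schema.name, len(schema.inputs), len(schema.outputs)))
# prints:     ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1);
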

Generated TableGen op definitions (written via args.tdfile)

@@ -1,3 +1,9 @@
+//********************************************************
+// Warning: Do not modify this file directly
+// This file is automatically generated via script
+// Details can be found in doc/readonnxdefs.md
+//********************************************************
def ONNXAbsOp:ONNX_Op<"Abs",
[NoSideEffect]> {
let summary = "ONNX Abs operation";
@@ -166,7 +172,7 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool",
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
DefaultValuedAttr<I64Attr, "0">:$ceil_mode,
DefaultValuedAttr<I64Attr, "0">:$count_include_pad,
-DefaultValuedAttr<I64ArrayAttr, "{}">:$kernel_shape,
+I64ArrayAttr:$kernel_shape,
OptionalAttr<I64ArrayAttr>:$pads,
OptionalAttr<I64ArrayAttr>:$strides);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
@@ -245,7 +251,7 @@ def ONNXCastOp:ONNX_Op<"Cast",
"an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input,
-DefaultValuedAttr<I64Attr, "0">:$to);
+I64Attr:$to);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_output);
}
@@ -297,7 +303,7 @@ def ONNXConcatOp:ONNX_Op<"Concat",
"Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on."
}];
let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs,
-DefaultValuedAttr<I64Attr, "0">:$axis);
+I64Attr:$axis);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_concat_result);
}
@@ -1621,7 +1627,7 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool",
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
DefaultValuedAttr<I64Attr, "0">:$ceil_mode,
OptionalAttr<I64ArrayAttr>:$dilations,
-DefaultValuedAttr<I64ArrayAttr, "{}">:$kernel_shape,
+I64ArrayAttr:$kernel_shape,
OptionalAttr<I64ArrayAttr>:$pads,
DefaultValuedAttr<I64Attr, "0">:$storage_order,
OptionalAttr<I64ArrayAttr>:$strides);
@@ -2122,8 +2128,8 @@ def ONNXRNNOp:ONNX_Op<"RNN",
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h,
-DefaultValuedAttr<F32ArrayAttr, "{}">:$activation_alpha,
-DefaultValuedAttr<F32ArrayAttr, "{}">:$activation_beta,
+OptionalAttr<F32ArrayAttr>:$activation_alpha,
+OptionalAttr<F32ArrayAttr>:$activation_beta,
DefaultValuedAttr<StrArrayAttr, "{\"Tanh\", \"Tanh\"}">:$activations,
OptionalAttr<F32Attr>:$clip,
DefaultValuedAttr<StrAttr, "forward">:$direction,
@@ -3519,7 +3525,7 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
""
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
-DefaultValuedAttr<I64ArrayAttr, "{}">:$axes);
+I64ArrayAttr:$axes);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_expanded);
}
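
Taken together, the .td hunks above show three attribute-declaration shapes. The selection rule below is inferred from the diff, not quoted from gen_doc.py: attributes the ONNX schema marks required (Cast's to, the pooling ops' kernel_shape, Unsqueeze's axes, Concat's axis) drop their placeholder defaults; attributes with a genuine schema default keep DefaultValuedAttr; the rest (RNN's activation_alpha and activation_beta) become OptionalAttr.

# Inferred selection rule for emitted attribute declarations (a sketch,
# not the actual gen_doc.py code).
def attr_decl(td_type, name, required=False, default=None):
    if required:                 # e.g. I64Attr:$to
        return '%s:$%s' % (td_type, name)
    if default is not None:      # e.g. DefaultValuedAttr<I64Attr, "0">:$ceil_mode
        return 'DefaultValuedAttr<%s, "%s">:$%s' % (td_type, default, name)
    return 'OptionalAttr<%s>:$%s' % (td_type, name)  # e.g. $activation_alpha

print(attr_decl('I64Attr', 'to', required=True))
print(attr_decl('I64Attr', 'ceil_mode', default='0'))
print(attr_decl('F32ArrayAttr', 'activation_alpha'))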