clean up gen_doc.py (#59)
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
ce8594fc60
commit
6959cf4586
|
@ -18,6 +18,44 @@ from onnx.backend.sample.ops import collect_sample_implementations
|
||||||
from typing import Any, Text, Sequence, Dict, List, Type, Set, Tuple
|
from typing import Any, Text, Sequence, Dict, List, Type, Set, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
#controls on ONNF code gen
|
||||||
|
#specify attr default value
|
||||||
|
special_attr_defaults = dict([
|
||||||
|
# ("AveragePool "+"kernel_shape", ('ints', '{}')),
|
||||||
|
# ("MaxPool "+"kernel_shape", ('ints', '{}')),
|
||||||
|
# ("Cast "+"to", ('int', '0')),
|
||||||
|
# ("Concat "+"axis", ('int', '0')),
|
||||||
|
# ("Conv "+"group", ('int', '1')),
|
||||||
|
# ("Unsqueeze "+"axes", ('ints', '{}')),
|
||||||
|
# ("RNN "+"activation_alpha", ('floats', '{}')),
|
||||||
|
# ("RNN "+"activation_beta", ('floats', '{}')),
|
||||||
|
])
|
||||||
|
|
||||||
|
#specify the function name in src/builder/frontend_dialect_transformer.cpp
|
||||||
|
#the reason for Conv and MaPool is to handled optional arguments
|
||||||
|
special_op_handler = dict([
|
||||||
|
("Conv", "ImportNodeConv"),
|
||||||
|
("MaxPool", "ImportNodeMaxPool"),
|
||||||
|
#("Transpose", "ImportNodeTranspose")
|
||||||
|
])
|
||||||
|
|
||||||
|
#add an Op in this list if ShapeInterference is defined for this Op
|
||||||
|
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
||||||
|
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
||||||
|
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
||||||
|
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
|
||||||
|
'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
|
||||||
|
|
||||||
|
CanonicalList=['Add', 'Identity']
|
||||||
|
|
||||||
|
manual_code_in_op_def = dict([
|
||||||
|
('DummyExample', ' let extraClassDeclaration = [{ \n'+
|
||||||
|
' static StringRef getPermAttrName() { return "perm"; }\n'+
|
||||||
|
' }];\n')
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
SNIPPETS = collect_snippets()
|
SNIPPETS = collect_snippets()
|
||||||
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
|
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
|
||||||
ONNX_ML = not bool(os.getenv('ONNX_ML') == '0')
|
ONNX_ML = not bool(os.getenv('ONNX_ML') == '0')
|
||||||
|
@ -263,18 +301,6 @@ def collect_types(schema, input) :
|
||||||
return allowedTypeStr
|
return allowedTypeStr
|
||||||
|
|
||||||
def gen_schema(schema) :
|
def gen_schema(schema) :
|
||||||
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
|
||||||
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
|
||||||
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
|
||||||
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
|
|
||||||
'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
|
|
||||||
CanonicalList=['Add', 'Identity']
|
|
||||||
manual_code = dict([
|
|
||||||
('DummyExample', ' let extraClassDeclaration = [{ \n'+
|
|
||||||
' static StringRef getPermAttrName() { return "perm"; }\n'+
|
|
||||||
' }];\n')
|
|
||||||
])
|
|
||||||
skip_attr_gen = []
|
skip_attr_gen = []
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
|
|
||||||
|
@ -362,8 +388,8 @@ def gen_schema(schema) :
|
||||||
|
|
||||||
#s+= 'let hasCanonicalizer = 1;'
|
#s+= 'let hasCanonicalizer = 1;'
|
||||||
#add special code
|
#add special code
|
||||||
if schema.name in manual_code :
|
if schema.name in manual_code_in_op_def :
|
||||||
s += manual_code[schema.name]
|
s += manual_code_in_op_def[schema.name]
|
||||||
|
|
||||||
s += '}\n\n'
|
s += '}\n\n'
|
||||||
|
|
||||||
|
@ -378,19 +404,14 @@ special cases:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def gen_code(schema,fefile) :
|
def gen_code(schema,fefile) :
|
||||||
special_handler = dict([
|
|
||||||
("Conv", "ImportNodeConv"),
|
|
||||||
("MaxPool", "ImportNodeMaxPool"),
|
|
||||||
#("Transpose", "ImportNodeTranspose")
|
|
||||||
])
|
|
||||||
|
|
||||||
handle_variadic = False
|
handle_variadic = False
|
||||||
|
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
||||||
op_type_str='mlir::ONNX'+schema.name+'Op'
|
op_type_str='mlir::ONNX'+schema.name+'Op'
|
||||||
if schema.name in special_handler :
|
if schema.name in special_op_handler :
|
||||||
fefile.write(' '+special_handler[schema.name]+'(node, '
|
fefile.write(' '+special_op_handler[schema.name]+'(node, '
|
||||||
+str(len(schema.inputs))
|
+str(len(schema.inputs))
|
||||||
+', ' +str(len(schema.outputs)))
|
+', ' +str(len(schema.outputs)))
|
||||||
elif len(schema.outputs) > 1 :
|
elif len(schema.outputs) > 1 :
|
||||||
|
@ -419,16 +440,6 @@ def gen_code(schema,fefile) :
|
||||||
fefile.write(', '+variadicIn+', '+variadicOut+');\n')
|
fefile.write(', '+variadicIn+', '+variadicOut+');\n')
|
||||||
|
|
||||||
def gen_attr_ins(schema, isfirst) :
|
def gen_attr_ins(schema, isfirst) :
|
||||||
special_defaults = dict([
|
|
||||||
("AveragePool "+"kernel_shape", ('ints', '{}')),
|
|
||||||
("MaxPool "+"kernel_shape", ('ints', '{}')),
|
|
||||||
("Cast "+"to", ('int', '0')),
|
|
||||||
("Concat "+"axis", ('int', '0')),
|
|
||||||
("Conv "+"group", ('int', '1')),
|
|
||||||
("Unsqueeze "+"axes", ('ints', '{}')),
|
|
||||||
("RNN "+"activation_alpha", ('floats', '{}')),
|
|
||||||
("RNN "+"activation_beta", ('floats', '{}')),
|
|
||||||
])
|
|
||||||
|
|
||||||
def get_attr_type_basic(attr_type) :
|
def get_attr_type_basic(attr_type) :
|
||||||
if attr_type == 'int' :
|
if attr_type == 'int' :
|
||||||
|
@ -469,8 +480,8 @@ def gen_attr_ins(schema, isfirst) :
|
||||||
else :
|
else :
|
||||||
isfirst = False
|
isfirst = False
|
||||||
|
|
||||||
if schema.name+' '+attr.name in special_defaults:
|
if schema.name+' '+attr.name in special_attr_defaults:
|
||||||
(attr_type_str, attr_default_str) = special_defaults[schema.name+' '+attr.name]
|
(attr_type_str, attr_default_str) = special_attr_defaults[schema.name+' '+attr.name]
|
||||||
attr_line += get_attr_type_with_default(attr_type_str, attr_default_str)
|
attr_line += get_attr_type_with_default(attr_type_str, attr_default_str)
|
||||||
attr_line += ':$'+attr.name
|
attr_line += ':$'+attr.name
|
||||||
elif attr.required:
|
elif attr.required:
|
||||||
|
@ -614,9 +625,21 @@ def main(args): # type: (Type[Args]) -> None
|
||||||
|
|
||||||
fout.write('\n')
|
fout.write('\n')
|
||||||
tdfile= io.open(args.tdfile, 'w', newline='')
|
tdfile= io.open(args.tdfile, 'w', newline='')
|
||||||
|
tdfile.write('//********************************************************\n'+
|
||||||
|
'// Warning: Do not modify this file directly\n'+
|
||||||
|
'// This file is automatically generated via script\n'+
|
||||||
|
'// Details can be found in doc/readonnxdefs.md\n'+
|
||||||
|
'//********************************************************\n\n'
|
||||||
|
)
|
||||||
fefile=io.open('op_build_table.inc', 'w', newline='')
|
fefile=io.open('op_build_table.inc', 'w', newline='')
|
||||||
firstfunc = True
|
firstfunc = True
|
||||||
|
|
||||||
|
fefile.write('//********************************************************\n'+
|
||||||
|
'// Warning: Do not modify this file directly\n'+
|
||||||
|
'// This file is automatically generated via script\n'+
|
||||||
|
'// Details can be found in doc/readonnxdefs.md\n'+
|
||||||
|
'//********************************************************\n\n'
|
||||||
|
)
|
||||||
fefile.write(' '+'if (OpName == "DUMMY") {\n')
|
fefile.write(' '+'if (OpName == "DUMMY") {\n')
|
||||||
for domain, supportmap in operator_schemas:
|
for domain, supportmap in operator_schemas:
|
||||||
s = '## {}\n'.format(display_domain_short(domain))
|
s = '## {}\n'.format(display_domain_short(domain))
|
|
@ -1,3 +1,9 @@
|
||||||
|
//********************************************************
|
||||||
|
// Warning: Do not modify this file directly
|
||||||
|
// This file is automatically generated via script
|
||||||
|
// Details can be found in doc/readonnxdefs.md
|
||||||
|
//********************************************************
|
||||||
|
|
||||||
if (OpName == "DUMMY") {
|
if (OpName == "DUMMY") {
|
||||||
}else if (OpName == "Abs") {
|
}else if (OpName == "Abs") {
|
||||||
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1);
|
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1);
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
//********************************************************
|
||||||
|
// Warning: Do not modify this file directly
|
||||||
|
// This file is automatically generated via script
|
||||||
|
// Details can be found in doc/readonnxdefs.md
|
||||||
|
//********************************************************
|
||||||
|
|
||||||
def ONNXAbsOp:ONNX_Op<"Abs",
|
def ONNXAbsOp:ONNX_Op<"Abs",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
let summary = "ONNX Abs operation";
|
let summary = "ONNX Abs operation";
|
||||||
|
@ -166,7 +172,7 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool",
|
||||||
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
|
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
|
||||||
DefaultValuedAttr<I64Attr, "0">:$ceil_mode,
|
DefaultValuedAttr<I64Attr, "0">:$ceil_mode,
|
||||||
DefaultValuedAttr<I64Attr, "0">:$count_include_pad,
|
DefaultValuedAttr<I64Attr, "0">:$count_include_pad,
|
||||||
DefaultValuedAttr<I64ArrayAttr, "{}">:$kernel_shape,
|
I64ArrayAttr:$kernel_shape,
|
||||||
OptionalAttr<I64ArrayAttr>:$pads,
|
OptionalAttr<I64ArrayAttr>:$pads,
|
||||||
OptionalAttr<I64ArrayAttr>:$strides);
|
OptionalAttr<I64ArrayAttr>:$strides);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
||||||
|
@ -245,7 +251,7 @@ def ONNXCastOp:ONNX_Op<"Cast",
|
||||||
"an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type."
|
"an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input,
|
||||||
DefaultValuedAttr<I64Attr, "0">:$to);
|
I64Attr:$to);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_output);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -297,7 +303,7 @@ def ONNXConcatOp:ONNX_Op<"Concat",
|
||||||
"Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on."
|
"Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs,
|
let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs,
|
||||||
DefaultValuedAttr<I64Attr, "0">:$axis);
|
I64Attr:$axis);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_concat_result);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_concat_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1621,7 +1627,7 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool",
|
||||||
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
|
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
|
||||||
DefaultValuedAttr<I64Attr, "0">:$ceil_mode,
|
DefaultValuedAttr<I64Attr, "0">:$ceil_mode,
|
||||||
OptionalAttr<I64ArrayAttr>:$dilations,
|
OptionalAttr<I64ArrayAttr>:$dilations,
|
||||||
DefaultValuedAttr<I64ArrayAttr, "{}">:$kernel_shape,
|
I64ArrayAttr:$kernel_shape,
|
||||||
OptionalAttr<I64ArrayAttr>:$pads,
|
OptionalAttr<I64ArrayAttr>:$pads,
|
||||||
DefaultValuedAttr<I64Attr, "0">:$storage_order,
|
DefaultValuedAttr<I64Attr, "0">:$storage_order,
|
||||||
OptionalAttr<I64ArrayAttr>:$strides);
|
OptionalAttr<I64ArrayAttr>:$strides);
|
||||||
|
@ -2122,8 +2128,8 @@ def ONNXRNNOp:ONNX_Op<"RNN",
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens,
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h,
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h,
|
||||||
DefaultValuedAttr<F32ArrayAttr, "{}">:$activation_alpha,
|
OptionalAttr<F32ArrayAttr>:$activation_alpha,
|
||||||
DefaultValuedAttr<F32ArrayAttr, "{}">:$activation_beta,
|
OptionalAttr<F32ArrayAttr>:$activation_beta,
|
||||||
DefaultValuedAttr<StrArrayAttr, "{\"Tanh\", \"Tanh\"}">:$activations,
|
DefaultValuedAttr<StrArrayAttr, "{\"Tanh\", \"Tanh\"}">:$activations,
|
||||||
OptionalAttr<F32Attr>:$clip,
|
OptionalAttr<F32Attr>:$clip,
|
||||||
DefaultValuedAttr<StrAttr, "forward">:$direction,
|
DefaultValuedAttr<StrAttr, "forward">:$direction,
|
||||||
|
@ -3519,7 +3525,7 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
|
||||||
""
|
""
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
||||||
DefaultValuedAttr<I64ArrayAttr, "{}">:$axes);
|
I64ArrayAttr:$axes);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_expanded);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_expanded);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue