[MLIR] generate op from onnx document (#366)

* generate op from onnx document

* Restore FullGemm

* update the op attribute for shape inference and canonicalizer

* Update onnx_canonicalization.mlir
This commit is contained in:
TONG CHEN 2019-11-18 21:08:21 -05:00 committed by Tian Jin
parent d01ac7732f
commit 3f68c5420d
9 changed files with 4143 additions and 93 deletions

View File

@ -1,5 +1,6 @@
add_library(builder
frontend_dialect_transformer.cpp
op_build_table.inc
)
target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})

View File

@ -149,6 +149,33 @@ class FrontendGenImpl {
}
}
//if c++17 is used, these two def can be combined with 'if constexpr'
//leave n there for possible future use
//alternative is to use template and pass the outputTypes, inputs and attributes
//as parameter
#define MultipleOuts(name, nIn, nOut)\
{ \
if (nIn == inputs.size() && nOut == outputTypes.size()) {\
auto op = builder_.create<mlir::ONNX##name##Op>(UnknownLoc(), outputTypes, inputs, attributes); \
for (int i = 0; i < node.output().size(); i++) { \
frontend_symbols_.AddMapping(\
legalize_name(node.output()[i]), op.getResult(i));\
}\
return;\
}\
}
#define OneOut(name, nIn, nOut)\
{ \
if (nIn == inputs.size() && nOut == outputTypes.size()) {\
auto op = builder_.create<mlir::ONNX##name##Op>(UnknownLoc(), outputTypes, inputs, attributes); \
frontend_symbols_.AddMapping(\
legalize_name(node.output()[0]), op.getResult());\
return;\
}\
}
/*!
* Import an onnx input tensor type by determining and recording its type
* in a list of input tensor mlir types.
@ -206,38 +233,22 @@ class FrontendGenImpl {
}
}
// Handle ONNX Add Operation by using its representation in the
// ONNX Dialect.
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
llvm::StringRef OpName = node.op_type();
if (OpName == "Add") {
auto op = builder_.create<mlir::ONNXAddOp>(UnknownLoc(),
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
inputs[1]);
frontend_symbols_.AddMapping(
legalize_name(node.output()[0]), op.getResult());
return;
} else if (OpName == "MatMul") {
auto op = builder_.create<mlir::ONNXMatMulOp>(UnknownLoc(),
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
inputs[1]);
frontend_symbols_.AddMapping(
legalize_name(node.output()[0]), op.getResult());
return;
} else if (OpName == "Gemm") {
if (inputs.size() == 3) {
auto op = builder_.create<mlir::ONNXFullGemmOp>(UnknownLoc(),
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
inputs[1], inputs[2]);
frontend_symbols_.AddMapping(
legalize_name(node.output()[0]), op.getResult());
} else {
auto op = builder_.create<mlir::ONNXGemmOp>(UnknownLoc(),
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs);
frontend_symbols_.AddMapping(
legalize_name(node.output()[0]), op.getResult());
}
return;
}
//the following code is generated by gen_doc.py
//refer to dialect/onnx/onnx.td for details
//when the input or output of then op does not match the specification,
//the generic operator is used
//one known reeason is the optional input
#include "src/builder/op_build_table.inc"
// Old way of doing things.
mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());

View File

@ -0,0 +1,313 @@
if (OpName == "Abs") {
OneOut(Abs, 1, 1);
}else if (OpName == "Acos") {
OneOut(Acos, 1, 1);
}else if (OpName == "Acosh") {
OneOut(Acosh, 1, 1);
}else if (OpName == "Add") {
OneOut(Add, 2, 1);
}else if (OpName == "And") {
OneOut(And, 2, 1);
}else if (OpName == "ArgMax") {
OneOut(ArgMax, 1, 1);
}else if (OpName == "ArgMin") {
OneOut(ArgMin, 1, 1);
}else if (OpName == "Asin") {
OneOut(Asin, 1, 1);
}else if (OpName == "Asinh") {
OneOut(Asinh, 1, 1);
}else if (OpName == "Atan") {
OneOut(Atan, 1, 1);
}else if (OpName == "Atanh") {
OneOut(Atanh, 1, 1);
}else if (OpName == "AveragePool") {
OneOut(AveragePool, 1, 1);
}else if (OpName == "BatchNormalization") {
MultipleOuts(BatchNormalization, 5, 5);
}else if (OpName == "BitShift") {
OneOut(BitShift, 2, 1);
}else if (OpName == "Cast") {
OneOut(Cast, 1, 1);
}else if (OpName == "Ceil") {
OneOut(Ceil, 1, 1);
}else if (OpName == "Clip") {
OneOut(Clip, 3, 1);
}else if (OpName == "Compress") {
OneOut(Compress, 2, 1);
}else if (OpName == "Concat") {
OneOut(Concat, 1, 1);
}else if (OpName == "ConcatFromSequence") {
OneOut(ConcatFromSequence, 1, 1);
}else if (OpName == "Constant") {
OneOut(Constant, 0, 1);
}else if (OpName == "ConstantOfShape") {
OneOut(ConstantOfShape, 1, 1);
}else if (OpName == "Conv") {
OneOut(Conv, 3, 1);
}else if (OpName == "ConvInteger") {
OneOut(ConvInteger, 4, 1);
}else if (OpName == "ConvTranspose") {
OneOut(ConvTranspose, 3, 1);
}else if (OpName == "Cos") {
OneOut(Cos, 1, 1);
}else if (OpName == "Cosh") {
OneOut(Cosh, 1, 1);
}else if (OpName == "CumSum") {
OneOut(CumSum, 2, 1);
}else if (OpName == "DepthToSpace") {
OneOut(DepthToSpace, 1, 1);
}else if (OpName == "DequantizeLinear") {
OneOut(DequantizeLinear, 3, 1);
}else if (OpName == "Det") {
OneOut(Det, 1, 1);
}else if (OpName == "Div") {
OneOut(Div, 2, 1);
}else if (OpName == "Dropout") {
MultipleOuts(Dropout, 1, 2);
}else if (OpName == "DynamicQuantizeLinear") {
MultipleOuts(DynamicQuantizeLinear, 1, 3);
}else if (OpName == "Elu") {
OneOut(Elu, 1, 1);
}else if (OpName == "Equal") {
OneOut(Equal, 2, 1);
}else if (OpName == "Erf") {
OneOut(Erf, 1, 1);
}else if (OpName == "Exp") {
OneOut(Exp, 1, 1);
}else if (OpName == "Expand") {
OneOut(Expand, 2, 1);
}else if (OpName == "EyeLike") {
OneOut(EyeLike, 1, 1);
}else if (OpName == "Flatten") {
OneOut(Flatten, 1, 1);
}else if (OpName == "Floor") {
OneOut(Floor, 1, 1);
}else if (OpName == "GRU") {
MultipleOuts(GRU, 6, 2);
}else if (OpName == "Gather") {
OneOut(Gather, 2, 1);
}else if (OpName == "GatherElements") {
OneOut(GatherElements, 2, 1);
}else if (OpName == "GatherND") {
OneOut(GatherND, 2, 1);
}else if (OpName == "Gemm") {
OneOut(Gemm, 3, 1);
}else if (OpName == "GlobalAveragePool") {
OneOut(GlobalAveragePool, 1, 1);
}else if (OpName == "GlobalLpPool") {
OneOut(GlobalLpPool, 1, 1);
}else if (OpName == "GlobalMaxPool") {
OneOut(GlobalMaxPool, 1, 1);
}else if (OpName == "Greater") {
OneOut(Greater, 2, 1);
}else if (OpName == "HardSigmoid") {
OneOut(HardSigmoid, 1, 1);
}else if (OpName == "Hardmax") {
OneOut(Hardmax, 1, 1);
}else if (OpName == "Identity") {
OneOut(Identity, 1, 1);
}else if (OpName == "If") {
OneOut(If, 1, 1);
}else if (OpName == "InstanceNormalization") {
OneOut(InstanceNormalization, 3, 1);
}else if (OpName == "IsInf") {
OneOut(IsInf, 1, 1);
}else if (OpName == "IsNaN") {
OneOut(IsNaN, 1, 1);
}else if (OpName == "LRN") {
OneOut(LRN, 1, 1);
}else if (OpName == "LSTM") {
MultipleOuts(LSTM, 8, 3);
}else if (OpName == "LeakyRelu") {
OneOut(LeakyRelu, 1, 1);
}else if (OpName == "Less") {
OneOut(Less, 2, 1);
}else if (OpName == "Log") {
OneOut(Log, 1, 1);
}else if (OpName == "LogSoftmax") {
OneOut(LogSoftmax, 1, 1);
}else if (OpName == "Loop") {
OneOut(Loop, 3, 1);
}else if (OpName == "LpNormalization") {
OneOut(LpNormalization, 1, 1);
}else if (OpName == "LpPool") {
OneOut(LpPool, 1, 1);
}else if (OpName == "MatMul") {
OneOut(MatMul, 2, 1);
}else if (OpName == "MatMulInteger") {
OneOut(MatMulInteger, 4, 1);
}else if (OpName == "Max") {
OneOut(Max, 1, 1);
}else if (OpName == "MaxPool") {
MultipleOuts(MaxPool, 1, 2);
}else if (OpName == "MaxRoiPool") {
OneOut(MaxRoiPool, 2, 1);
}else if (OpName == "MaxUnpool") {
OneOut(MaxUnpool, 3, 1);
}else if (OpName == "Mean") {
OneOut(Mean, 1, 1);
}else if (OpName == "MeanVarianceNormalization") {
OneOut(MeanVarianceNormalization, 1, 1);
}else if (OpName == "Min") {
OneOut(Min, 1, 1);
}else if (OpName == "Mod") {
OneOut(Mod, 2, 1);
}else if (OpName == "Mul") {
OneOut(Mul, 2, 1);
}else if (OpName == "Multinomial") {
OneOut(Multinomial, 1, 1);
}else if (OpName == "Neg") {
OneOut(Neg, 1, 1);
}else if (OpName == "NonMaxSuppression") {
OneOut(NonMaxSuppression, 5, 1);
}else if (OpName == "NonZero") {
OneOut(NonZero, 1, 1);
}else if (OpName == "Not") {
OneOut(Not, 1, 1);
}else if (OpName == "OneHot") {
OneOut(OneHot, 3, 1);
}else if (OpName == "Or") {
OneOut(Or, 2, 1);
}else if (OpName == "PRelu") {
OneOut(PRelu, 2, 1);
}else if (OpName == "Pad") {
OneOut(Pad, 3, 1);
}else if (OpName == "Pow") {
OneOut(Pow, 2, 1);
}else if (OpName == "QLinearConv") {
OneOut(QLinearConv, 9, 1);
}else if (OpName == "QLinearMatMul") {
OneOut(QLinearMatMul, 8, 1);
}else if (OpName == "QuantizeLinear") {
OneOut(QuantizeLinear, 3, 1);
}else if (OpName == "RNN") {
MultipleOuts(RNN, 6, 2);
}else if (OpName == "RandomNormal") {
OneOut(RandomNormal, 0, 1);
}else if (OpName == "RandomNormalLike") {
OneOut(RandomNormalLike, 1, 1);
}else if (OpName == "RandomUniform") {
OneOut(RandomUniform, 0, 1);
}else if (OpName == "RandomUniformLike") {
OneOut(RandomUniformLike, 1, 1);
}else if (OpName == "Range") {
OneOut(Range, 3, 1);
}else if (OpName == "Reciprocal") {
OneOut(Reciprocal, 1, 1);
}else if (OpName == "ReduceL1") {
OneOut(ReduceL1, 1, 1);
}else if (OpName == "ReduceL2") {
OneOut(ReduceL2, 1, 1);
}else if (OpName == "ReduceLogSum") {
OneOut(ReduceLogSum, 1, 1);
}else if (OpName == "ReduceLogSumExp") {
OneOut(ReduceLogSumExp, 1, 1);
}else if (OpName == "ReduceMax") {
OneOut(ReduceMax, 1, 1);
}else if (OpName == "ReduceMean") {
OneOut(ReduceMean, 1, 1);
}else if (OpName == "ReduceMin") {
OneOut(ReduceMin, 1, 1);
}else if (OpName == "ReduceProd") {
OneOut(ReduceProd, 1, 1);
}else if (OpName == "ReduceSum") {
OneOut(ReduceSum, 1, 1);
}else if (OpName == "ReduceSumSquare") {
OneOut(ReduceSumSquare, 1, 1);
}else if (OpName == "Relu") {
OneOut(Relu, 1, 1);
}else if (OpName == "Reshape") {
OneOut(Reshape, 2, 1);
}else if (OpName == "Resize") {
OneOut(Resize, 4, 1);
}else if (OpName == "ReverseSequence") {
OneOut(ReverseSequence, 2, 1);
}else if (OpName == "RoiAlign") {
OneOut(RoiAlign, 3, 1);
}else if (OpName == "Round") {
OneOut(Round, 1, 1);
}else if (OpName == "Scan") {
OneOut(Scan, 1, 1);
}else if (OpName == "Scatter") {
OneOut(Scatter, 3, 1);
}else if (OpName == "ScatterElements") {
OneOut(ScatterElements, 3, 1);
}else if (OpName == "ScatterND") {
OneOut(ScatterND, 3, 1);
}else if (OpName == "Selu") {
OneOut(Selu, 1, 1);
}else if (OpName == "SequenceAt") {
OneOut(SequenceAt, 2, 1);
}else if (OpName == "SequenceConstruct") {
OneOut(SequenceConstruct, 1, 1);
}else if (OpName == "SequenceEmpty") {
OneOut(SequenceEmpty, 0, 1);
}else if (OpName == "SequenceErase") {
OneOut(SequenceErase, 2, 1);
}else if (OpName == "SequenceInsert") {
OneOut(SequenceInsert, 3, 1);
}else if (OpName == "SequenceLength") {
OneOut(SequenceLength, 1, 1);
}else if (OpName == "Shape") {
OneOut(Shape, 1, 1);
}else if (OpName == "Shrink") {
OneOut(Shrink, 1, 1);
}else if (OpName == "Sigmoid") {
OneOut(Sigmoid, 1, 1);
}else if (OpName == "Sign") {
OneOut(Sign, 1, 1);
}else if (OpName == "Sin") {
OneOut(Sin, 1, 1);
}else if (OpName == "Sinh") {
OneOut(Sinh, 1, 1);
}else if (OpName == "Size") {
OneOut(Size, 1, 1);
}else if (OpName == "Slice") {
OneOut(Slice, 5, 1);
}else if (OpName == "Softmax") {
OneOut(Softmax, 1, 1);
}else if (OpName == "Softplus") {
OneOut(Softplus, 1, 1);
}else if (OpName == "Softsign") {
OneOut(Softsign, 1, 1);
}else if (OpName == "SpaceToDepth") {
OneOut(SpaceToDepth, 1, 1);
}else if (OpName == "Split") {
OneOut(Split, 1, 1);
}else if (OpName == "SplitToSequence") {
OneOut(SplitToSequence, 2, 1);
}else if (OpName == "Sqrt") {
OneOut(Sqrt, 1, 1);
}else if (OpName == "Squeeze") {
OneOut(Squeeze, 1, 1);
}else if (OpName == "StringNormalizer") {
OneOut(StringNormalizer, 1, 1);
}else if (OpName == "Sub") {
OneOut(Sub, 2, 1);
}else if (OpName == "Sum") {
OneOut(Sum, 1, 1);
}else if (OpName == "Tan") {
OneOut(Tan, 1, 1);
}else if (OpName == "Tanh") {
OneOut(Tanh, 1, 1);
}else if (OpName == "TfIdfVectorizer") {
OneOut(TfIdfVectorizer, 1, 1);
}else if (OpName == "ThresholdedRelu") {
OneOut(ThresholdedRelu, 1, 1);
}else if (OpName == "Tile") {
OneOut(Tile, 2, 1);
}else if (OpName == "TopK") {
MultipleOuts(TopK, 2, 2);
}else if (OpName == "Transpose") {
OneOut(Transpose, 1, 1);
}else if (OpName == "Unique") {
MultipleOuts(Unique, 1, 4);
}else if (OpName == "Unsqueeze") {
OneOut(Unsqueeze, 1, 1);
}else if (OpName == "Upsample") {
OneOut(Upsample, 2, 1);
}else if (OpName == "Where") {
OneOut(Where, 3, 1);
}else if (OpName == "Xor") {
OneOut(Xor, 2, 1);
}

View File

@ -10,8 +10,9 @@ add_library(
dialect/krnl/parser_helper.hpp
pass/shape_inference_pass.cpp
pass/shape_inference_interface.hpp
pass/onnx_combine.cpp
pass/passes.hpp)
pass/passes.hpp
dialect/onnx/onnxop.inc
pass/onnx_combine.cpp)
# Include root src directory.
target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})

View File

@ -0,0 +1,520 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from collections import defaultdict
import io
import os
import sys
import numpy as np # type: ignore
from onnx import defs, FunctionProto, helper, OperatorStatus
from onnx.defs import OpSchema, ONNX_DOMAIN, ONNX_ML_DOMAIN
from onnx.backend.test.case import collect_snippets
from onnx.backend.sample.ops import collect_sample_implementations
from typing import Any, Text, Sequence, Dict, List, Type, Set, Tuple
SNIPPETS = collect_snippets()
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
ONNX_ML = not bool(os.getenv('ONNX_ML') == '0')
ONNX_ML = False
print("ONNX_ML", ONNX_ML)
if ONNX_ML:
ext = '-ml.md'
else:
ext = '.md'
def display_number(v): # type: (int) -> Text
if defs.OpSchema.is_infinite(v):
return '&#8734;'
return Text(v)
def should_render_domain(domain): # type: (Text) -> bool
if domain == ONNX_ML_DOMAIN and not ONNX_ML:
return False
elif ONNX_ML and domain != ONNX_ML_DOMAIN:
return False
return True
def format_name_with_domain(domain, schema_name): # type: (Text, Text) -> Text
if domain:
return '{}.{}'.format(domain, schema_name)
else:
return schema_name
def display_attr_type(v): # type: (OpSchema.AttrType) -> Text
assert isinstance(v, OpSchema.AttrType)
s = Text(v)
s = s[s.rfind('.') + 1:].lower()
if s[-1] == 's':
s = 'list of ' + s
return s
def display_domain(domain): # type: (Text) -> Text
if domain:
return "the '{}' operator set".format(domain)
else:
return "the default ONNX operator set"
def display_domain_short(domain): # type: (Text) -> Text
if domain:
return domain
else:
return 'ai.onnx (default)'
def display_version_link(name, version): # type: (Text, int) -> Text
changelog_md = 'Changelog' + ext
name_with_ver = '{}-{}'.format(name, version)
return '<a href="{}#{}">{}</a>'.format(changelog_md, name_with_ver, name_with_ver)
def display_schema(schema, versions): # type: (OpSchema, Sequence[OpSchema]) -> Text
s = ''
# doc
if schema.doc:
s += '\n'
s += '\n'.join(' ' + line
for line in schema.doc.lstrip().splitlines())
s += '\n'
# since version
s += '\n#### Version\n'
if schema.support_level == OpSchema.SupportType.EXPERIMENTAL:
s += '\nNo versioning maintained for experimental ops.'
else:
s += '\nThis version of the operator has been ' + ('deprecated' if schema.deprecated else 'available') + ' since version {}'.format(schema.since_version)
s += ' of {}.\n'.format(display_domain(schema.domain))
if len(versions) > 1:
# TODO: link to the Changelog.md
s += '\nOther versions of this operator: {}\n'.format(
', '.join(display_version_link(format_name_with_domain(v.domain, v.name),
v.since_version) for v in versions[:-1]))
# If this schema is deprecated, don't display any of the following sections
if schema.deprecated:
return s
# attributes
if schema.attributes:
s += '\n#### Attributes\n\n'
s += '<dl>\n'
for _, attr in sorted(schema.attributes.items()):
# option holds either required or default value
opt = ''
if attr.required:
opt = 'required'
elif attr.default_value.name:
default_value = helper.get_attribute_value(attr.default_value)
def format_value(value): # type: (Any) -> Text
if isinstance(value, float):
formatted = str(np.round(value, 5))
# use default formatting, unless too long.
if (len(formatted) > 10):
formatted = str("({:e})".format(value))
return formatted
elif isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3:
return str(value.decode('utf-8'))
return str(value)
if isinstance(default_value, list):
default_value = [format_value(val) for val in default_value]
else:
default_value = format_value(default_value)
opt = 'default is {}'.format(default_value)
s += '<dt><tt>{}</tt> : {}{}</dt>\n'.format(
attr.name,
display_attr_type(attr.type),
' ({})'.format(opt) if opt else '')
s += '<dd>{}</dd>\n'.format(attr.description)
s += '</dl>\n'
# inputs
s += '\n#### Inputs'
if schema.min_input != schema.max_input:
s += ' ({} - {})'.format(display_number(schema.min_input),
display_number(schema.max_input))
s += '\n\n'
if schema.inputs:
s += '<dl>\n'
for input in schema.inputs:
option_str = ""
if OpSchema.FormalParameterOption.Optional == input.option:
option_str = " (optional)"
elif OpSchema.FormalParameterOption.Variadic == input.option:
if input.isHomogeneous:
option_str = " (variadic)"
else:
option_str = " (variadic, heterogeneous)"
s += '<dt><tt>{}</tt>{} : {}</dt>\n'.format(input.name, option_str, input.typeStr)
s += '<dd>{}</dd>\n'.format(input.description)
s += '</dl>\n'
# outputs
s += '\n#### Outputs'
if schema.min_output != schema.max_output:
s += ' ({} - {})'.format(display_number(schema.min_output),
display_number(schema.max_output))
s += '\n\n'
if schema.outputs:
s += '<dl>\n'
for output in schema.outputs:
option_str = ""
if OpSchema.FormalParameterOption.Optional == output.option:
option_str = " (optional)"
elif OpSchema.FormalParameterOption.Variadic == output.option:
if output.isHomogeneous:
option_str = " (variadic)"
else:
option_str = " (variadic, heterogeneous)"
s += '<dt><tt>{}</tt>{} : {}</dt>\n'.format(output.name, option_str, output.typeStr)
s += '<dd>{}</dd>\n'.format(output.description)
s += '</dl>\n'
# type constraints
s += '\n#### Type Constraints'
s += '\n\n'
if schema.type_constraints:
s += '<dl>\n'
for type_constraint in schema.type_constraints:
allowedTypes = type_constraint.allowed_type_strs
if (len(allowedTypes) > 0):
allowedTypeStr = allowedTypes[0]
for allowedType in allowedTypes[1:]:
allowedTypeStr += ', ' + allowedType
s += '<dt><tt>{}</tt> : {}</dt>\n'.format(
type_constraint.type_param_str, allowedTypeStr)
s += '<dd>{}</dd>\n'.format(type_constraint.description)
s += '</dl>\n'
# Function Body
if schema.has_function: # type: ignore
s += '\n#### Function\n'
s += '\nThe Function can be represented as a function.\n'
return s
def support_level_str(level): # type: (OpSchema.SupportType) -> Text
return \
"<sub>experimental</sub> " if level == OpSchema.SupportType.EXPERIMENTAL else ""
def convert_type(tstr) :
tfrom = np.array(['bool', 'int8', 'int16', 'int32', 'int64',
'unkown', 'float16', 'float', 'double'])
tto =np.array(['I1', 'I8', 'I16', 'I32', 'I64',
'BF16', 'F16', 'F32', 'F64'])
index = -1
for i in range(len(tfrom)) :
if tfrom[i] in tstr :
index = i
break
if index == -1 :
print("error", tstr)
return ''
else :
return tto[i]
def collect_types(schema, input) :
allowedTypeStr=''
#first step just ignore the type constraints
return allowedTypeStr
if input.typeStr :
tstr = input.typeStr
else :
return allwedTypeStr
if schema.type_constraints:
for type_constraint in schema.type_constraints:
if type_constraint.type_param_str != tstr :
continue
allowedTypes = type_constraint.allowed_type_strs
allowedTypeStr=''
if (len(allowedTypes) > 0):
t = convert_type(allowedTypes[0])
if t == '' :
return ''
allowedTypeStr += t
for allowedType in allowedTypes[1:]:
t = convert_type(allowedType)
if t == '' :
return ''
if not t in allowedTypeStr :
allowedTypeStr += ', '+t
return allowedTypeStr
return allowedTypeStr
def gen_schema(schema) :
ShapeInferenceList=['Add', 'MatMul', 'Gemm']
CanonicalList=['Add']
line_indent = ' '
#s = 'def ONNX'+schema.name+str(schema.since_version)+'Op:ONNX_Op<"'+schema.name+'", \n'
s = 'def ONNX'+schema.name+'Op:ONNX_Op<"'+schema.name+'", \n'
s += line_indent+' [NoSideEffect'
if schema.name in ShapeInferenceList :
s+= ', DeclareOpInterfaceMethods<ShapeInferenceOpInterface>'
s += ']> {'
if schema.name in CanonicalList:
s += '\n'+line_indent+'let hasCanonicalizer = 1;'
#summary
s += '\n'+line_indent
s += 'let summary = "ONNX '+schema.name+' operation";'
#description
s += '\n'+line_indent
s += 'let description = [{'
if schema.doc:
"""
s += '\n'.join(line_indent + line
for line in schema.doc.lstrip().splitlines())
"""
for line in schema.doc.lstrip().splitlines():
line = line.replace('}]', '\}\]')
s += '\n'+line_indent+' '+'"'+line+'"'
else :
s += '\n'+line_indent*2 +'no doc for this op from onnx'
s += '\n'+line_indent+'}];'
#input
s+= '\n'+line_indent+'let arguments = (ins '
if schema.inputs:
for input in schema.inputs:
if input != schema.inputs[0] :
s+= ', '
etypes=collect_types(schema, input)
if OpSchema.FormalParameterOption.Optional == input.option:
#TODO: handle optional
print("optional ", input.name)
elif OpSchema.FormalParameterOption.Variadic == input.option:
if input.isHomogeneous:
s+= 'Variadic<'
else:
#TODO handle (variadic, heterogeneous)"
print('variadic, heterogeneous', input.name)
if etypes == '':
s+= 'AnyTensor'
else:
s+= 'TensorOf<['+etypes+']>'
if OpSchema.FormalParameterOption.Optional == input.option:
#TODO: handle optional
t=''
elif OpSchema.FormalParameterOption.Variadic == input.option:
if input.isHomogeneous:
s+= '>'
else:
#TODO handle (variadic, heterogeneous)"
t=''
s+=':$'+input.name
s+= ');'
#output
s+= '\n'+line_indent+'let results = (outs '
if schema.outputs:
for output in schema.outputs:
if output != schema.outputs[0] :
s+= ', '
#need to interpret output.typeStr
etypes=collect_types(schema, output)
if etypes == '':
s+= 'AnyTensor'
else:
s+= 'TensorOf<['+etypes+']>'
s+= ');'
#s+= 'let hasCanonicalizer = 1;'
s += '\n}\n\n'
return s
def main(args): # type: (Type[Args]) -> None
with io.open(args.changelog, 'w', newline='') as fout:
fout.write('## Operator Changelog\n')
fout.write(
"*This file is automatically generated from the\n"
" [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
" Do not modify directly and instead edit operator definitions.*\n")
# domain -> version -> [schema]
dv_index = defaultdict(lambda: defaultdict(list)) # type: Dict[Text, Dict[int, List[OpSchema]]]
for schema in defs.get_all_schemas_with_history():
dv_index[schema.domain][schema.since_version].append(schema)
fout.write('\n')
for domain, versionmap in sorted(dv_index.items()):
print("domain", domain)
if not should_render_domain(domain):
continue
s = '# {}\n'.format(display_domain_short(domain))
for version, unsorted_schemas in sorted(versionmap.items()):
s += '## Version {} of {}\n'.format(version, display_domain(domain))
for schema in sorted(unsorted_schemas, key=lambda s: s.name):
name_with_ver = '{}-{}'.format(format_name_with_domain(domain, schema.name),
schema.since_version)
s += ('### <a name="{}"></a>**{}**' + (' (deprecated)' if schema.deprecated else '') + '</a>\n').format(name_with_ver, name_with_ver)
s += display_schema(schema, [schema])
s += '\n'
fout.write(s)
with io.open(args.output, 'w', newline='', encoding="utf-8") as fout:
fout.write('## Operator Schemas\n')
fout.write(
"*This file is automatically generated from the\n"
" [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
" Do not modify directly and instead edit operator definitions.*\n")
# domain -> support level -> name -> [schema]
index = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
for schema in defs.get_all_schemas_with_history():
#print("check point 0", schema.name, schema.domain, schema.support_level)
#gen_schema(schema)
index[schema.domain][int(schema.support_level)][schema.name].append(schema)
fout.write('\n')
# Preprocess the Operator Schemas
# [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
operator_schemas = list() # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
exsting_ops = set() # type: Set[Text]
for domain, _supportmap in sorted(index.items()):
if not should_render_domain(domain):
continue
processed_supportmap = list()
for _support, _namemap in sorted(_supportmap.items()):
processed_namemap = list()
for n, unsorted_versions in sorted(_namemap.items()):
versions = sorted(unsorted_versions, key=lambda s: s.since_version)
schema = versions[-1]
#print("check point 2", schema)
if schema.name in exsting_ops:
continue
exsting_ops.add(schema.name)
processed_namemap.append((n, schema, versions))
processed_supportmap.append((_support, processed_namemap))
operator_schemas.append((domain, processed_supportmap))
# Table of contents
for domain, supportmap in operator_schemas:
s = '* {}\n'.format(display_domain_short(domain))
fout.write(s)
function_ops = list()
for _, namemap in supportmap:
for n, schema, versions in namemap:
if schema.has_function: # type: ignore
function_ops.append((n, schema, versions))
continue
s = ' * {}<a href="#{}">{}</a>\n'.format(
support_level_str(schema.support_level),
format_name_with_domain(domain, n),
format_name_with_domain(domain, n))
fout.write(s)
if len(function_ops):
fout.write('\n')
fout.write(' **Operators with function registered:**\n')
for n, schema, versions in function_ops:
s = ' * {}<a href="#{}">{}</a>\n'.format(
support_level_str(schema.support_level),
format_name_with_domain(domain, n),
format_name_with_domain(domain, n))
fout.write(s)
fout.write('\n')
tdfile= io.open(args.tdfile, 'w', newline='')
fefile=io.open('op_build_table.inc', 'w', newline='')
firstfunc = True
for domain, supportmap in operator_schemas:
s = '## {}\n'.format(display_domain_short(domain))
fout.write(s)
for _, namemap in supportmap:
for op_type, schema, versions in namemap:
# op_type
#print("check point 1", schema.name, len(schema.inputs), len(schema.outputs))
if firstfunc :
fefile.write(' '+'if (OpName == "'+schema.name+'") {\n')
firstfunc = False
else :
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
if len(schema.outputs) > 1 :
fefile.write(' '+'MultipleOuts('+schema.name+', '
+str(schema.since_version)+', '
+str(len(schema.inputs))+', '
+str(len(schema.outputs))+');\n')
else :
fefile.write(' '+'OneOut('+schema.name+', '
+str(schema.since_version)+', '
+str(len(schema.inputs))+', '
+str(len(schema.outputs))+');\n')
r = gen_schema(schema)
tdfile.write(r)
s = ('### {}<a name="{}"></a><a name="{}">**{}**' + (' (deprecated)' if schema.deprecated else '') + '</a>\n').format(
support_level_str(schema.support_level),
format_name_with_domain(domain, op_type),
format_name_with_domain(domain, op_type.lower()),
format_name_with_domain(domain, op_type))
s += display_schema(schema, versions)
s += '\n\n'
if op_type in SNIPPETS:
s += '#### Examples\n\n'
for summary, code in sorted(SNIPPETS[op_type]):
s += '<details>\n'
s += '<summary>{}</summary>\n\n'.format(summary)
s += '```python\n{}\n```\n\n'.format(code)
s += '</details>\n'
s += '\n\n'
if op_type.lower() in SAMPLE_IMPLEMENTATIONS:
s += '#### Sample Implementation\n\n'
s += '<details>\n'
s += '<summary>{}</summary>\n\n'.format(op_type)
s += '```python\n{}\n```\n\n'.format(SAMPLE_IMPLEMENTATIONS[op_type.lower()])
s += '</details>\n'
s += '\n\n'
fout.write(s)
fefile.write(' }')
fefile.close()
if __name__ == '__main__':
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
docs_dir = os.path.join(base_dir, 'docs')
print(docs_dir)
class Args(object):
output = os.path.join(docs_dir, 'Operators' + ext)
changelog = os.path.join(docs_dir, 'Changelog' + ext)
tdfile = os.path.join(docs_dir, 'onnxop.inc')
print(Args)
main(Args)

View File

@ -39,54 +39,24 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
// ONNX Operations
//===----------------------------------------------------------------------===//
// We define an ONNX operation for adding two tensors elementwise.
def ONNXAddOp: ONNX_Op<"add",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX add operation";
let description = [{
//the tablegen code onnxop.in is generated with gen_doc.py
//clone and install onnx
// git clone --recursive https://github.com/onnx/onnx.git
// set up env for anaconda3 and for DLC (BOOSTROOT, cmake, gcc ...)
// cd onnx
//install onnx
// CC=gcc CXX=g++ pip install -e .
//run the script
// python onnx/defs/gen_doc.py
//result is in docs/onnxop.inc
//current limitations:
// 1. Attributes are not processed
// 2. output type inference not implemented except Add
// 3. Type Attribute: 'optional' and 'Variadic hetergeneous' are ignored
// 4. type of string, complex64 and complex128 for input/output are ignored
// 5. unsigned int are treated as signed one
The "onnx.add" adds two tensors element-wise.
}];
// TODO: AnyTensor might be too wide for ONNX and may need to be constrained
// to fewer valid types.
// In the ONNX spec:
// T : tensor(uint32), tensor(uint64),
// tensor(int32), tensor(int64),
// tensor(float16), tensor(float), tensor(double)
//
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in);
let results = (outs AnyTensor);
let hasCanonicalizer = 1;
}
def ONNXMatMulOp: ONNX_Op<"matmul",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX matrix multiply operation";
let description = [{
The "onnx.mul" multiplies two matrices.
}];
let arguments = (ins AnyTypeOf<[F32Tensor, F64Tensor]>:$lhs_in,
AnyTypeOf<[F32Tensor, F64Tensor]>:$rhs_in);
let results = (outs AnyTypeOf<[F32Tensor, F64Tensor]>);
}
def ONNXGemmOp: ONNX_Op<"gemm",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation";
let description = [{
The "onnx.gemm" generic matrix multiplication with bias.
}];
let arguments = (ins Variadic<AnyTensor>:$inputs);
let results = (outs AnyTensor);
}
include "dialect/onnx/onnxop.inc"
def ONNXFullGemmOp: ONNX_Op<"full_gemm",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {

View File

@ -39,9 +39,9 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
//===----------------------------------------------------------------------===//
// ONNX Operations
//===----------------------------------------------------------------------===//
// Add
/// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface.
void ONNXAddOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}

File diff suppressed because it is too large Load Diff

View File

@ -7,8 +7,8 @@ module {
%1 = "frontend.input t2"() : () -> tensor<10x10xf32>
%2 = "frontend.input t3"() : () -> tensor<10x10xf32>
// CHECK: %{{[0-9]+}} = "onnx.full_gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%3 = "onnx.matmul"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%4 = "onnx.add"(%3, %2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%3 = "onnx.MatMul"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%4 = "onnx.Add"(%3, %2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%5 = "frontend.output t4"(%4) : (tensor<10x10xf32>) -> tensor<10x10xf32>
}
}