2019-11-19 10:08:21 +08:00
|
|
|
#!/usr/bin/env python
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
|
2020-02-24 23:46:48 +08:00
|
|
|
from collections import defaultdict, OrderedDict
|
2020-04-08 15:00:34 +08:00
|
|
|
from io import StringIO
|
2019-11-19 10:08:21 +08:00
|
|
|
import io
|
|
|
|
import os
|
|
|
|
import sys
|
2020-02-24 23:46:48 +08:00
|
|
|
import datetime
|
2020-04-08 15:00:34 +08:00
|
|
|
import argparse
|
2019-11-19 10:08:21 +08:00
|
|
|
|
|
|
|
import numpy as np # type: ignore
|
|
|
|
|
|
|
|
from onnx import defs, FunctionProto, helper, OperatorStatus
|
|
|
|
from onnx.defs import OpSchema, ONNX_DOMAIN, ONNX_ML_DOMAIN
|
|
|
|
from onnx.backend.test.case import collect_snippets
|
|
|
|
from onnx.backend.sample.ops import collect_sample_implementations
|
|
|
|
from typing import Any, Text, Sequence, Dict, List, Type, Set, Tuple
|
|
|
|
|
2020-05-21 01:48:45 +08:00
|
|
|
import pprint
|
|
|
|
|
2020-04-08 15:00:34 +08:00
|
|
|
# Command-line options of the generator script.
parser = argparse.ArgumentParser()
parser.add_argument("--dry-run-onnx-ops",
                    help="Output ONNXOps.td.inc content to stdout.",
                    action="store_true",
                    default=False)
parser.add_argument("--dry-run-op-build-table",
                    help="Output OpBuildTable.inc content to stdout.",
                    action="store_true",
                    default=False)
# Compare the operations (and their versions) provided by the installed onnx
# package against the versions recorded in `version_dict`.
parser.add_argument("--check-operation-version",
                    help="Check whether the imported onnx package has new "
                    "operations or newer versions of operations compared "
                    "with the versions stored in version_dict.",
                    action="store_true",
                    default=False)

args = parser.parse_args()

check_operation_version = args.check_operation_version
|
|
|
|
|
|
|
|
|
|
|
|
# Version of each operation that is treated as the current (implemented)
# version. Run this script with --check-operation-version to find out whether
# the onnx package in use offers a newer version of any operation, and update
# this dictionary whenever a newer version gets implemented.
# TODO: how to keep the old version
version_dict = {
    'Abs': 6, 'Acos': 7, 'Acosh': 9, 'Add': 7, 'And': 7, 'ArgMax': 11,
    'ArgMin': 11, 'Asin': 7, 'Asinh': 9, 'Atan': 7, 'Atanh': 9,
    'AveragePool': 11, 'BatchNormalization': 9, 'BitShift': 11, 'Cast': 9,
    'Ceil': 6, 'Clip': 11, 'Compress': 11, 'Concat': 11,
    'ConcatFromSequence': 11, 'Constant': 11, 'ConstantOfShape': 9,
    'Conv': 11, 'ConvInteger': 10, 'ConvTranspose': 11, 'Cos': 7, 'Cosh': 9,
    'CumSum': 11, 'DepthToSpace': 11, 'DequantizeLinear': 10, 'Det': 11,
    'Div': 7, 'Dropout': 10, 'DynamicQuantizeLinear': 11, 'Elu': 6,
    'Equal': 11, 'Erf': 9, 'Exp': 6, 'Expand': 8, 'EyeLike': 9,
    'Flatten': 11, 'Floor': 6, 'GRU': 7, 'Gather': 11, 'GatherElements': 11,
    'GatherND': 11, 'Gemm': 11, 'GlobalAveragePool': 1, 'GlobalLpPool': 2,
    'GlobalMaxPool': 1, 'Greater': 9, 'HardSigmoid': 6, 'Hardmax': 11,
    'Identity': 1, 'If': 11, 'InstanceNormalization': 6, 'IsInf': 10,
    'IsNaN': 9, 'LRN': 1, 'LSTM': 7, 'LeakyRelu': 6, 'Less': 9, 'Log': 6,
    'LogSoftmax': 11, 'Loop': 11, 'LpNormalization': 1, 'LpPool': 11,
    'MatMul': 9, 'MatMulInteger': 10, 'Max': 8, 'MaxPool': 11,
    'MaxRoiPool': 1, 'MaxUnpool': 11, 'Mean': 8,
    'MeanVarianceNormalization': 9, 'Min': 8, 'Mod': 10, 'Mul': 7,
    'Multinomial': 7, 'Neg': 6, 'NonMaxSuppression': 11, 'NonZero': 9,
    'Not': 1, 'OneHot': 11, 'Or': 7, 'PRelu': 9, 'Pad': 11, 'Pow': 7,
    'QLinearConv': 10, 'QLinearMatMul': 10, 'QuantizeLinear': 10, 'RNN': 7,
    'RandomNormal': 1, 'RandomNormalLike': 1, 'RandomUniform': 1,
    'RandomUniformLike': 1, 'Range': 11, 'Reciprocal': 6, 'ReduceL1': 11,
    'ReduceL2': 11, 'ReduceLogSum': 11, 'ReduceLogSumExp': 11,
    'ReduceMax': 11, 'ReduceMean': 11, 'ReduceMin': 11, 'ReduceProd': 11,
    'ReduceSum': 11, 'ReduceSumSquare': 11, 'Relu': 6, 'Reshape': 5,
    'Resize': 11, 'ReverseSequence': 10, 'RoiAlign': 10, 'Round': 11,
    'Scan': 11, 'Scatter': 11, 'ScatterElements': 11, 'ScatterND': 11,
    'Selu': 6, 'SequenceAt': 11, 'SequenceConstruct': 11,
    'SequenceEmpty': 11, 'SequenceErase': 11, 'SequenceInsert': 11,
    'SequenceLength': 11, 'Shape': 1, 'Shrink': 9, 'Sigmoid': 6, 'Sign': 9,
    'Sin': 7, 'Sinh': 9, 'Size': 1, 'Slice': 11, 'Softmax': 11,
    'Softplus': 1, 'Softsign': 1, 'SpaceToDepth': 1, 'Split': 11,
    'SplitToSequence': 11, 'Sqrt': 6, 'Squeeze': 11, 'StringNormalizer': 10,
    'Sub': 7, 'Sum': 8, 'Tan': 7, 'Tanh': 6, 'TfIdfVectorizer': 9,
    'ThresholdedRelu': 10, 'Tile': 6, 'TopK': 11, 'Transpose': 1,
    'Unique': 11, 'Unsqueeze': 11, 'Upsample': 10, 'Where': 9, 'Xor': 7,
    # ONNX-ML (ai.onnx.ml) operators.
    'ArrayFeatureExtractor': 1, 'Binarizer': 1, 'CastMap': 1,
    'CategoryMapper': 1, 'DictVectorizer': 1, 'FeatureVectorizer': 1,
    'Imputer': 1, 'LabelEncoder': 2, 'LinearClassifier': 1,
    'LinearRegressor': 1, 'Normalizer': 1, 'OneHotEncoder': 1,
    'SVMClassifier': 1, 'SVMRegressor': 1, 'Scaler': 1,
    'TreeEnsembleClassifier': 1, 'TreeEnsembleRegressor': 1, 'ZipMap': 1,
}
|
|
|
|
|
2020-02-24 23:46:48 +08:00
|
|
|
# Manual specification of attribute defaults. Keys have the form
# "<OpName>.<attr_name>"; values are (onnx attribute type, default value text)
# pairs that get_attrs turns into a DefaultValuedAttr declaration.
# No overrides are active at the moment; entries would look like:
#   "AveragePool.kernel_shape": ('ints', '{}'),
#   "MaxPool.kernel_shape": ('ints', '{}'),
#   "Cast.to": ('int', '0'),
#   "Concat.axis": ('int', '0'),
#   "Conv.group": ('int', '1'),
#   "Unsqueeze.axes": ('ints', '{}'),
#   "RNN.activation_alpha": ('floats', '{}'),
#   "RNN.activation_beta": ('floats', '{}'),
special_attr_defaults = {}
|
|
|
|
|
|
|
|
# Special operation importing handlers: ONNX op name -> name of the dedicated
# import routine for that op.
# ("Transpose", "ImportNodeTranspose") is currently disabled.
special_op_handler = {
    "MaxPool": "ImportNodeMaxPool",
    "BatchNormalization": "ImportNodeBatchNormalization",
    "Pad": "ImportNodePad",
    "Reshape": "ImportNodeReshape",
}
|
|
|
|
|
|
|
|
# Operations supporting shape inference; gen_op_def gives each of them the
# DeclareOpInterfaceMethods<ShapeInferenceOpInterface> trait.
OpsWithShapeInference = """
    Exp Atan Tan Tanh Sin Sinh Cosh Sigmoid Relu
    Add Mul Div Sub And Or Xor Sum Max Min MatMul
    Gemm LeakyRelu Elu Selu HardSigmoid Reshape Reciprocal
    Identity Cos Log Transpose Softmax ReduceMax ReduceMin
    ReduceProd ReduceSum Softplus Softsign Sqrt Unsqueeze
    Sign Constant AveragePool Abs Conv Concat Neg RNN
    LSTM GRU Split Pad Cast ConvTranspose Flatten
    DynamicQuantizeLinear QuantizeLinear DequantizeLinear ConvInteger
""".split()
|
|
|
|
|
|
|
|
# Operations supporting canonicalization; gen_op_def emits
# `let hasCanonicalizer = 1;` for each of them.
OpsWithCanonicalizer = "Add Identity Gemm Conv".split()
|
2020-03-19 15:03:37 +08:00
|
|
|
|
|
|
|
# Operations with operands that, if produced by constant operations, should be
# promoted to become attributes (via attribute promotion).
#
# Each key is an operation name. Its value is a list of
# (attribute/operand name, index of that operand in the operation's inputs)
# tuples describing how attribute promotion should proceed.
OpsWithPromotableConstOperands = {
    "Reshape": [("shape", 1)],
    "Pad": [("pads", 1), ("constant_value", 2)],
}
|
2020-02-24 23:46:48 +08:00
|
|
|
|
2020-05-26 09:54:19 +08:00
|
|
|
# Interface for special handling of result type inference. Each value is a C++
# snippet that pushes the op's result types onto `resultTypes`; the common
# code wrapped around it is produced by get_type_inference_func.
OpsWithResultTypeInference = {
    "Constant": "\n".join([
        "if (auto attr = valueAttr()) {",
        "resultTypes.push_back(attr.getType());",
        "} else if (auto attr = sparse_valueAttr()) {",
        "resultTypes.push_back(attr.getType());",
        "}",
    ]),
    "Cast": "\n".join([
        "auto toAttr = to().getSExtValue();",
        "auto builder = mlir::OpBuilder(getContext());",
        "resultTypes.push_back(mlir::UnrankedTensorType::get(",
        "convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));",
    ]),
}
|
2020-06-09 14:55:49 +08:00
|
|
|
|
2020-02-24 23:46:48 +08:00
|
|
|
# Add an op to this list if it needs result type deduction, which is required
# when writing declarative rewriting rules. The deduced type is always an
# UnrankedTensorType whose element type is the same as the first operand's
# element type.
#
# Currently, only two build methods are generated:
#   - one where each operand and attribute is a separate parameter, and
#   - one where operands and attributes are passed as aggregated parameters.
custom_builder_unranked_ops_list = "Abs Exp ReduceSum ReduceSumSquare Pad".split()

# Custom builder op list for operations with broadcast. The output type is
# deduced from the broadcast of the two operands instead of always being left
# unranked as in the list above. These ops must have exactly two operands.
# TODO: handle the variadic ops omitted here: Max, Min, Sum.
custom_builder_broadcast_ops_list = (
    "Add And Div Equal Greater Less Mul Or Pow Sub Xor".split())

# Union of both lists.
custom_builder_ops_list = (custom_builder_unranked_ops_list
                           + custom_builder_broadcast_ops_list)
|
2020-05-15 13:19:28 +08:00
|
|
|
|
|
|
|
# A dictionary of extra TableGen text for specific operations; gen_op_def
# appends the text (plus a newline) inside the op's definition.
custom_definition_misc = {
    'Constant': '\n'.join([
        ' let builders = [',
        'OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{',
        'if (value) {',
        'auto tensorType = value.getType();',
        'build(builder, state, tensorType, sparse_value, value);',
        '} else {',
        'auto tensorType = sparse_value.getType();',
        'build(builder, state, tensorType, sparse_value, value);',
        '}',
        '}]>',
        '];',
    ]),
}
|
|
|
|
|
2020-02-21 22:28:24 +08:00
|
|
|
|
2020-05-22 10:03:16 +08:00
|
|
|
# Parallel tuples: onnx_types[i] is an ONNX element-type name and
# tblgen_types[i] is the TableGen type emitted for it.
# np_type_to_tblgen_attr_type picks the FIRST name that occurs as a substring
# of an ONNX type string, so a name containing another name must be listed
# before it: 'bfloat16' before 'float16', and 'float16' before 'float'.
# (Substring matching also maps e.g. 'uint8' to the signless 'I8'.)
onnx_types = (
    'bool', 'int8', 'int16', 'int32', 'int64', 'bfloat16', 'float16',
    'float', 'double', 'complex64', 'complex128'
)
tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64',
                'Complex<F32>', 'Complex<F64>'
)
assert len(onnx_types) == len(tblgen_types), \
    "onnx_types and tblgen_types must be index-aligned"

# In getTypeMap() codes, values below MAX_NUM_TYPES are indices into
# tblgen_types; a value MAX_NUM_TYPES + i means "same type as input i".
MAX_NUM_TYPES = 20
|
|
|
|
|
2019-11-19 10:08:21 +08:00
|
|
|
# Test-case snippets (onnx.backend.test.case) and sample implementations
# (onnx.backend.sample.ops) collected from the onnx package at import time,
# in that order.
SNIPPETS, SAMPLE_IMPLEMENTATIONS = (collect_snippets(),
                                    collect_sample_implementations())
|
|
|
|
|
|
|
|
def should_render_domain(domain):  # type: (Text) -> bool
    """Report whether operations of *domain* should be rendered.

    No domain filtering is applied: every domain is rendered.
    """
    del domain  # unused: all domains are accepted
    return True
|
|
|
|
|
|
|
|
|
|
|
|
def display_attr_type(v):  # type: (OpSchema.AttrType) -> Text
    """Return a human-readable name for an ONNX attribute type.

    The enum's text form (e.g. ``AttrType.INTS``) is reduced to the part after
    the last '.', lower-cased; names ending in 's' are shown as a list, e.g.
    ``list of ints``.
    """
    assert isinstance(v, OpSchema.AttrType)
    name = Text(v).rpartition('.')[2].lower()
    return 'list of ' + name if name[-1] == 's' else name
|
|
|
|
|
|
|
|
|
2020-02-08 00:10:35 +08:00
|
|
|
def get_unique_output_name(schema, name):
    """Return an output name that does not clash with an input of *schema*.

    If some input is already called *name*, ``'out_' + name`` is returned;
    otherwise *name* is returned unchanged.
    """
    clashes = any(operand.name == name for operand in schema.inputs)
    return 'out_' + name if clashes else name
|
2019-11-19 10:08:21 +08:00
|
|
|
|
|
|
|
|
2020-02-24 23:46:48 +08:00
|
|
|
def onnx_attr_type_to_mlir_attr_type(t):
    """Map an ONNX attribute type to its MLIR (TableGen) attribute type name.

    *t* is converted to text and only the part after the last '.' is used,
    lower-cased, so both ``AttrType.INTS`` and ``ints`` are accepted.
    Types without a dedicated mapping become ``AnyAttr``.
    """
    # TODO: tensor and sparse tensor
    type_name = Text(t)
    type_name = type_name[type_name.rfind('.') + 1:].lower()
    known_types = {
        'int': 'I64Attr',
        'float': 'F32Attr',
        'ints': 'I64ArrayAttr',
        'floats': 'F32ArrayAttr',
        'string': 'StrAttr',
        'strings': 'StrArrayAttr',
    }
    return known_types.get(type_name, 'AnyAttr')
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: any better way to do this.
def tblgen_attr_type_to_cpp_type(t):
    """Return the C++ attribute class corresponding to a TableGen attr type.

    *t* is a type string as produced by get_attrs and may be wrapped, e.g.
    ``OptionalAttr<I64ArrayAttr>`` or ``DefaultValuedAttr<StrAttr, "...">``,
    so it is matched by substring. The result is used as a parameter type of
    the generated custom builders; unrecognized types map to ``Attribute``.
    """
    if 'I64Attr' in t:
        return 'IntegerAttr'
    if 'F32Attr' in t:
        return 'FloatAttr'
    # onnx_attr_type_to_mlir_attr_type emits StrArrayAttr for ONNX "strings";
    # like the other list types it is passed as an ArrayAttr. (None of these
    # array names contains a scalar name such as 'I64Attr' or 'StrAttr'.)
    if 'I64ArrayAttr' in t or 'F32ArrayAttr' in t or 'StrArrayAttr' in t:
        return 'ArrayAttr'
    if 'StrAttr' in t:
        return 'StringAttr'
    # Legacy spelling, kept for backward compatibility.
    if 'strings' in t:
        return 'ArrayAttr'
    return 'Attribute'
|
2019-11-19 10:08:21 +08:00
|
|
|
|
|
|
|
|
2020-02-24 23:46:48 +08:00
|
|
|
def tblgen_operand_type_to_cpp_type(op_type):
    """Return the C++ builder parameter type for a TableGen operand type.

    Variadic operands are passed as a ``ValueRange``; any other operand is a
    single ``Value``.
    """
    is_variadic = op_type.startswith('Variadic')
    return 'ValueRange' if is_variadic else 'Value'
|
2019-11-19 10:08:21 +08:00
|
|
|
|
|
|
|
|
2020-02-24 23:46:48 +08:00
|
|
|
def np_type_to_tblgen_attr_type(tstr):
    """Map an ONNX type string to the matching TableGen element type.

    The names in ``onnx_types`` are tested in order for being a substring of
    *tstr* (so ``'tensor(float)'`` matches ``'float'``), and the entry at the
    same position in ``tblgen_types`` is returned for the first hit.

    Returns:
        The TableGen type string, or None if no name matches.
    """
    # Pair the tuples directly instead of relying on a loop index that leaks
    # out of the loop.
    for onnx_type, tblgen_type in zip(onnx_types, tblgen_types):
        if onnx_type in tstr:
            return tblgen_type
    return None
|
2019-11-19 10:08:21 +08:00
|
|
|
|
2020-05-22 10:03:16 +08:00
|
|
|
def get_tblgen_type_index(type_str):
    """Return the position of *type_str* within ``tblgen_types``.

    That position is the numeric type code written into getTypeMap(). A type
    string that is not in ``tblgen_types`` raises ValueError.
    """
    position = tblgen_types.index(type_str)
    return position
|
|
|
|
|
|
|
|
def get_data_structure_element(allowed_type_str):
    """Split an ONNX allowed-type string into (structure, element).

    The possible data structures are tensor, map and seq(tensor()). The
    structure is the first of 'tensor', 'seq', 'map' that *allowed_type_str*
    starts with. The element is the string with the first ``structure(`` and
    then the first ``)`` removed, e.g. ``'seq(tensor(float))'`` gives
    ``('seq', 'tensor(float)')``.

    Returns:
        ``(structure, element)``, or ``(None, None)`` for any other string.
    """
    matched = next((candidate for candidate in ('tensor', 'seq', 'map')
                    if allowed_type_str.startswith(candidate)), None)
    if matched is None:
        return (None, None)
    inner = allowed_type_str.replace(matched + '(', '', 1)
    inner = inner.replace(')', '', 1)
    return (matched, inner)
|
2020-02-24 23:46:48 +08:00
|
|
|
|
|
|
|
def get_allowed_elem_types(schema, input):
    """Resolve the container structure and TableGen element types of a value.

    Args:
        schema: the OpSchema the value belongs to.
        input: an entry of ``schema.inputs`` or ``schema.outputs``. Its
            ``typeStr`` is looked up among ``schema.type_constraints``.

    Returns:
        A ``(structure, elem_types)`` pair (always two values, since callers
        unpack the result):
          * ``(None, None)`` if the value has no type string, no constraint
            matches it, or an allowed type has an unknown structure;
          * ``(structure, None)`` if an element type has no TableGen
            equivalent;
          * ``(structure, [types])`` otherwise, deduplicated, in first-seen
            order.

    Exits the process (status -1) if one constraint allows more than one
    structure ('tensor', 'seq', 'map').
    """
    # TODO: enable type constraints.
    tstr = input.typeStr
    if not tstr:
        return None, None

    for type_constraint in schema.type_constraints or ():
        if type_constraint.type_param_str != tstr:
            continue
        allowed_type_list = []
        allowed_structure = None
        for allowed_type in type_constraint.allowed_type_strs:
            structure, element = get_data_structure_element(allowed_type)
            if structure is None or element is None:
                return None, None

            if allowed_structure is not None and allowed_structure != structure:
                # Report on stderr so the message cannot end up in the
                # generated output printed to stdout.
                sys.stderr.write("{}: one structure assumed\n".format(schema.name))
                sys.exit(-1)
            allowed_structure = structure

            t = np_type_to_tblgen_attr_type(element)
            if t is None:
                return allowed_structure, None
            if t not in allowed_type_list:
                allowed_type_list.append(t)
        return allowed_structure, allowed_type_list

    return None, None
|
2020-02-24 23:46:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
def inc_indent(indent=None):
    """Return *indent* extended by one level (two spaces).

    Passing None (the default) starts a fresh, empty indentation.
    """
    if indent is None:
        return ""
    return indent + "  "
|
|
|
|
|
|
|
|
|
|
|
|
def dec_indent(indent):
    """Return *indent* with one level (its last two characters) removed."""
    keep = max(len(indent) - 2, 0)
    return indent[:keep]
|
|
|
|
|
|
|
|
|
|
|
|
def join_args(args):
    """Join *args* into a comma-separated (", ") argument list."""
    separator = ", "
    return separator.join(args)
|
|
|
|
|
|
|
|
def get_operands_or_results(schema, is_input):
    """Map each operand (or result) of *schema* to its TableGen type.

    Args:
        schema: the OpSchema being translated.
        is_input: True to describe ``schema.inputs`` (operands), False to
            describe ``schema.outputs`` (results).

    Returns:
        An OrderedDict, in declaration order, from value name to a TableGen
        type constraint string such as
        ``AnyTypeOf<[TensorOf<[F32]>, MemRefOf<[F32]>]>``. A result whose name
        clashes with an input name is keyed by ``'out_' + name``.
    """
    value_list = schema.inputs if is_input else schema.outputs
    if not value_list:
        return OrderedDict()

    def any_type_of(types):
        # A single type is used as-is; several become an AnyTypeOf<[...]> union.
        assert isinstance(types, list)
        if len(types) == 1:
            return types[0]
        return "AnyTypeOf<[{}]>".format(", ".join(types))

    # Per structure: (types used when the element types are unknown,
    #                 templates filled with the comma-joined element types).
    # Seq and map are not supported yet: TensorOf<[...]> and TupleOf<[...]>
    # are placeholders for tablegen; when such an operation is used, a
    # warning/error will be generated at runtime.
    structure_types = {
        'tensor': (["AnyMemRef", "AnyTensor"],
                   ["TensorOf<[{}]>", "MemRefOf<[{}]>"]),
        'seq': (["AnyMemRef", "TensorOf<[AnyTensor]>"],
                ["TensorOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"]),
        'map': (["AnyMemRef", "TupleOf<[AnyTensor]>"],
                ["TupleOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"]),
    }

    # Indices of operands promotable to attributes. They index the
    # operation's inputs, so they never apply to results.
    promotable_idxs = set()
    if is_input and schema.name in OpsWithPromotableConstOperands:
        promotable_idxs = set(
            dict(OpsWithPromotableConstOperands[schema.name]).values())

    name_to_types = OrderedDict()
    for i, value in enumerate(value_list):
        structure, elem_types = get_allowed_elem_types(schema, value)

        if structure in structure_types:
            untyped, templates = structure_types[structure]
            if elem_types is None:
                types = list(untyped)  # copy: `types` is appended to below
            else:
                elem_types_str = ','.join(elem_types)
                types = [template.format(elem_types_str)
                         for template in templates]
        else:
            types = ["AnyMemRef", "AnyTensor"]

        is_optional = OpSchema.FormalParameterOption.Optional == value.option

        # If an operand is promotable to an attribute, then it must be
        # nullable in case it migrates to be an attribute.
        if i in promotable_idxs and not is_optional:
            types.append("NoneType")

        if is_optional:
            types.append("NoneType")
        elif OpSchema.FormalParameterOption.Variadic == value.option:
            if value.isHomogeneous:
                types = ["Variadic<{}>".format(any_type_of(types))]
            else:
                # TODO: handle (variadic, heterogeneous).
                sys.stderr.write("warning: (variadic, heterogeneous) for " +
                                 schema.name + ' ' + value.name + "\n")

        # Since an output name can coincide with that of an input, such
        # output names get an "out_" prefix for disambiguation.
        if is_input:
            value_name = value.name
        else:
            value_name = get_unique_output_name(schema, value.name)

        name_to_types[value_name] = any_type_of(types)
    return name_to_types
|
|
|
|
|
|
|
|
|
|
|
|
def get_attrs(schema):
    """Build the TableGen attribute declarations for *schema*.

    Returns an OrderedDict from attribute name to TableGen attribute type,
    in sorted attribute-name order. Each attribute is declared as:
      * DefaultValuedAttr<...> using the override from special_attr_defaults
        when "<Op>.<attr>" is listed there;
      * its plain MLIR attribute type when the attribute is required;
      * DefaultValuedAttr<...> with the schema's default value rendered as
        text when the schema provides a named default;
      * OptionalAttr<...> otherwise.
    """
    def optional_attr_type(attr_type):
        return 'OptionalAttr<{}>'.format(
            onnx_attr_type_to_mlir_attr_type(attr_type))

    def default_valued_attr_type(attr_type, attr_default):
        return 'DefaultValuedAttr<{}, "{}">'.format(
            onnx_attr_type_to_mlir_attr_type(attr_type), attr_default)

    def format_value(value):  # type: (Any) -> Text
        if isinstance(value, float):
            # Round to 5 decimals; use the default formatting unless too long.
            text = str(np.round(value, 5))
            if len(text) > 10:
                text = str("({:e})".format(value))
            return text
        if isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3:
            return str(value.decode('utf-8'))
        return str(value)

    def render_default(attr):
        default_value = helper.get_attribute_value(attr.default_value)
        if not isinstance(default_value, list):
            return format_value(default_value)
        # Lists are rendered with braces, e.g. [1, 2] -> {1, 2}.
        text = '{}'.format([format_value(item) for item in default_value])
        text = text.replace('[', '{', 1).replace(']', '}', 1)
        if Text(attr.type) == "AttrType.STRINGS":
            # String elements keep escaped double quotes.
            return text.replace("'", '\\"')
        return text.replace("'", '')

    if not schema.attributes:
        return OrderedDict()

    name_to_type = OrderedDict()
    for _, attr in sorted(schema.attributes.items()):
        qualified_attr_name = "{}.{}".format(schema.name, attr.name)
        if qualified_attr_name in special_attr_defaults:
            name_to_type[attr.name] = default_valued_attr_type(
                *special_attr_defaults[qualified_attr_name])
        elif attr.required:
            name_to_type[attr.name] = onnx_attr_type_to_mlir_attr_type(
                attr.type)
        elif attr.default_value.name:
            name_to_type[attr.name] = default_valued_attr_type(
                attr.type, render_default(attr))
        else:
            name_to_type[attr.name] = optional_attr_type(attr.type)
    return name_to_type
|
|
|
|
|
2020-05-22 10:03:16 +08:00
|
|
|
def get_numberof_list(mylist):
    """Return the number of values declared in *mylist*.

    Returns -1 instead when any value is variadic, since the count is then
    not fixed.
    """
    has_variadic = any(
        OpSchema.FormalParameterOption.Variadic == element.option
        for element in mylist)
    return -1 if has_variadic else len(mylist)
|
|
|
|
|
|
|
|
def get_output_type_mapping(schema):
    """Return one type code, as a string, per output of *schema*.

    The code is:
      * the index into ``tblgen_types`` when exactly one element type is
        allowed for the output;
      * ``MAX_NUM_TYPES + i`` when the output's type string equals that of
        input ``i`` (the first such input);
      * ``-1`` when the output type is unknown.
    These codes form the list returned by the generated getTypeMap().
    """
    codes = []
    for result in schema.outputs:
        _, elem_types = get_allowed_elem_types(schema, result)
        if elem_types is not None and len(elem_types) == 1:
            codes.append(str(get_tblgen_type_index(elem_types[0])))
            continue

        code = -1
        result_type_str = result.typeStr
        if result_type_str:
            for position, operand in enumerate(schema.inputs):
                if operand.typeStr and operand.typeStr == result_type_str:
                    code = position + MAX_NUM_TYPES
                    break
        codes.append(str(code))
    return codes
|
2020-06-09 14:55:49 +08:00
|
|
|
|
2020-05-22 10:03:16 +08:00
|
|
|
def get_numberof_inout(s, indent, schema):
    """Append operand/result count and type-map methods to *s*; return it.

    Emits static getNumberOfOperands(), getNumberOfResults() (counts from
    get_numberof_list, -1 when variadic) and getTypeMap() (codes from
    get_output_type_mapping). Each method header sits one indentation level
    deeper than *indent*, its return statement one level deeper still.
    """
    method_indent = inc_indent(indent)
    body_indent = inc_indent(method_indent)

    def emit_method(signature, statement):
        return (method_indent + signature + " {\n" +
                body_indent + statement + "\n" +
                method_indent + "}\n")

    num_operands = get_numberof_list(schema.inputs)
    s += emit_method("static int getNumberOfOperands()",
                     "return {};".format(num_operands))
    num_results = get_numberof_list(schema.outputs)
    s += emit_method("static int getNumberOfResults()",
                     "return {};".format(num_results))
    mapping = get_output_type_mapping(schema)
    s += emit_method("static std::vector<int> getTypeMap()",
                     "return {" + ",".join(mapping) + "};")
    return s
|
|
|
|
|
2020-02-24 23:46:48 +08:00
|
|
|
|
2020-03-19 15:03:37 +08:00
|
|
|
def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
    """Append a promotableConstOperands() C++ method to *s* and return it.

    The method returns a std::map literal built from the (operand name,
    operand index) pairs in *const_operands_name_to_idx*. Its header is one
    indentation level deeper than *indent*, its body one level deeper still.
    """
    entries = ["{{\"{}\", {}}}".format(*pair)
               for pair in const_operands_name_to_idx]
    cpp_name_to_idx_literal = "{" + ", ".join(entries) + "}"

    method_indent = inc_indent(indent)
    body_indent = inc_indent(method_indent)
    s += method_indent + "std::map<std::string, size_t> promotableConstOperands() {\n"
    s += body_indent + "return {};\n".format(cpp_name_to_idx_literal)
    s += method_indent + "}\n"
    return s
|
|
|
|
|
2020-05-26 09:54:19 +08:00
|
|
|
def get_type_inference_func(s, indent, type_inference_code):
    """Append a resultTypeInference() C++ method to *s* and return it.

    The method declares `resultTypes`, then contains *type_inference_code*
    verbatim (only its first line gets the body indentation), then returns
    `resultTypes`. The header is one indentation level deeper than *indent*.
    """
    method_indent = inc_indent(indent)
    body_indent = inc_indent(method_indent)
    lines = [
        method_indent + "std::vector<mlir::Type> resultTypeInference() {",
        body_indent + "std::vector<mlir::Type> resultTypes;",
        body_indent + type_inference_code,
        body_indent + "return resultTypes;",
        method_indent + "}",
    ]
    return s + "".join(line + "\n" for line in lines)
|
2020-06-09 14:55:49 +08:00
|
|
|
|
|
|
|
|
2020-03-19 15:03:37 +08:00
|
|
|
|
2020-02-24 23:46:48 +08:00
|
|
|
def _broadcast_result_type_code(indent, nested_indent, lhs, rhs):
    """Return C++ statements computing `elementType` for a broadcast builder.

    `elementType` is getBroadcastedType() of the RankedTensorTypes of *lhs*
    and *rhs* (C++ expressions for the two operands); when that is not a
    statically shaped type, it becomes an UnrankedTensorType of *lhs*'s
    element type. Statements use *indent*; the fallback body *nested_indent*.
    """
    return (
        indent + 'auto lhsTy = {}.getType().cast<RankedTensorType>();\n'.format(lhs) +
        indent + 'auto rhsTy = {}.getType().cast<RankedTensorType>();\n'.format(rhs) +
        indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n' +
        indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n' +
        indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n' +
        nested_indent + 'elementType = {}.getType().cast<TensorType>()'.format(lhs) +
        '.getElementType();\n' +
        nested_indent + 'elementType = UnrankedTensorType::get(elementType);\n' +
        indent + '}\n')


def _gen_custom_builders(schema, ins, indent):
    """Return the `let builders = [...];` text for a custom-builder op.

    Two builders are generated: one taking each operand and attribute as a
    separate parameter, and one taking aggregated `operands`/`attributes`.
    For ops in custom_builder_broadcast_ops_list the result type comes from
    _broadcast_result_type_code; otherwise it is an UnrankedTensorType of the
    first operand's element type.

    *ins* is the op's non-empty ordered map of operand and attribute names.
    The text starts at *indent* (as does its closing `];`) and ends with a
    newline.
    """
    is_broadcast = schema.name in custom_builder_broadcast_ops_list
    item_indent = inc_indent(indent)
    body_indent = inc_indent(item_indent)
    nested_indent = inc_indent(body_indent)

    s = indent + 'let builders = [\n'

    # Custom builder with operands and attributes each as a separate
    # parameter, e.g. OpBuilder<"OpBuilder &builder, OperationState &state,
    # Value X, Value Y, Attribute A", [{}]>
    s += item_indent + 'OpBuilder<"OpBuilder &builder, OperationState &state'
    operands_dict = get_operands_or_results(schema, is_input=True)
    for name, ty in operands_dict.items():
        s += ', {} {}'.format(tblgen_operand_type_to_cpp_type(ty), name)
    for name, ty in get_attrs(schema).items():
        s += ', {} {}'.format(tblgen_attr_type_to_cpp_type(ty), name)
    s += '", [{\n'

    # Get the output type from the first (two) operand(s).
    names = list(ins)
    first_operand_name = names[0]
    if is_broadcast:
        s += _broadcast_result_type_code(body_indent, nested_indent,
                                         first_operand_name, names[1])
        build_type_name = 'elementType'
    else:
        s += body_indent + 'auto elementType = {}'.format(first_operand_name) + \
            '.getType().cast<TensorType>().getElementType();\n'
        build_type_name = 'UnrankedTensorType::get(elementType)'
    s += body_indent + 'build(builder, state, {}'.format(build_type_name)
    for name in names:
        s += ', ' + name
    s += ');\n'
    s += item_indent + '}]>,\n'

    # Custom builder with aggregated operands and attributes, e.g.
    # OpBuilder<"OpBuilder &builder, OperationState &state,
    # ValueRange operands, ArrayRef<NamedAttribute> attributes", [{}]>
    s += item_indent + 'OpBuilder<"OpBuilder &builder, OperationState &state, ' + \
        'ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
    if is_broadcast:
        s += _broadcast_result_type_code(body_indent, nested_indent,
                                         'operands[0]', 'operands[1]')
    else:
        s += body_indent + 'auto elementType = operands[0].getType().' + \
            'cast<TensorType>().getElementType();\n'
    s += body_indent + 'std::vector<mlir::Type> outputTypes;\n'
    s += body_indent + 'outputTypes.emplace_back({});\n'.format(build_type_name)
    s += body_indent + 'build(builder, state, outputTypes, operands, attributes);\n'
    s += item_indent + '}]>'

    # Close the list at the level of `let builders`, so whatever follows is
    # indented the same as for ops without custom builders.
    s += '\n' + indent + '];\n'
    return s


def gen_op_def(schema):
    """Generate the TableGen definition of the ONNX op described by *schema*.

    The definition declares the op's traits, an optional canonicalizer, the
    summary and description, arguments (operands, then attributes), results,
    custom builders (ops in custom_builder_ops_list) and an
    extraClassDeclaration with operand/result counts, the output type map
    and, where configured, promotableConstOperands() and
    resultTypeInference(); custom_definition_misc text is appended last.

    Raises:
        RuntimeWarning: if custom builders are requested for an op without
            operands.
    """
    indent = inc_indent()
    s = 'def ONNX{0}Op:ONNX_Op<"{0}",\n'.format(schema.name)

    # Generate decl for op traits.
    traits = ["NoSideEffect"]
    if schema.name in OpsWithShapeInference:
        traits.append("DeclareOpInterfaceMethods<ShapeInferenceOpInterface>")
    if schema.name in OpsWithPromotableConstOperands:
        traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">")
    if schema.name in OpsWithResultTypeInference:
        traits.append("OpInterface<\"ResultTypeInferenceOpInterface\">")
    s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits))

    indent = inc_indent(indent)

    # Generate decl for canonicalizer.
    if schema.name in OpsWithCanonicalizer:
        s += indent + 'let hasCanonicalizer = 1;\n'

    # Generate decl for summary.
    s += indent + 'let summary = "ONNX {} operation";\n'.format(schema.name)

    # Generate description. Double quotes and '}]' (which would close the
    # TableGen [{ ... }] block) are escaped.
    s += indent + 'let description = [{\n'
    if schema.doc:
        for line in schema.doc.lstrip().splitlines():
            escaped_line = line.replace('"', '\\"').replace('}]', '\\}\\]')
            s += indent + '"{}"\n'.format(escaped_line)
    s += indent + '}];\n'

    # Generate ins (consisting of operands and attributes).
    ins = get_operands_or_results(schema, is_input=True)
    ins.update(get_attrs(schema))
    ins_strs = ["{1}:${0}".format(*i) for i in ins.items()]
    s += indent + 'let arguments = (ins {});\n'.format(
        (',\n' + inc_indent(indent)).join(ins_strs))

    # Generate outs (operation results).
    outs = get_operands_or_results(schema, is_input=False)
    outs_strs = ["{1}:${0}".format(*i) for i in outs.items()]
    s += indent + 'let results = (outs {});\n'.format(
        (',\n' + inc_indent(indent)).join(outs_strs))

    # Add custom builders.
    if schema.name in custom_builder_ops_list:
        if len(ins) == 0:
            raise RuntimeWarning(
                "warning: not generate custom build methods for " +
                schema.name + " since it does not have operands.")
        s += _gen_custom_builders(schema, ins, indent)

    # Generate extraClassDeclaration.
    s += indent + "let extraClassDeclaration = [{\n"
    # Generate input/output number and the output type map.
    s = get_numberof_inout(s, indent, schema)
    # Generate promotableConstOperands().
    if schema.name in OpsWithPromotableConstOperands:
        s = get_promotable_const_operands_func(
            s, indent, OpsWithPromotableConstOperands[schema.name])
    # Generate resultTypeInference().
    if schema.name in OpsWithResultTypeInference:
        s = get_type_inference_func(
            s, indent, OpsWithResultTypeInference[schema.name])
    s += indent + '}];\n'

    if schema.name in custom_definition_misc:
        s += custom_definition_misc[schema.name] + '\n'

    s += '}\n\n'
    return s
|
|
|
|
|
2020-02-24 23:46:48 +08:00
|
|
|
|
2019-12-21 14:58:23 +08:00
|
|
|
"""
|
|
|
|
special cases:
|
|
|
|
* Split: attr split default value: sizeof(output1) namely 1
|
|
|
|
* Conv: attr dilations default value is {num_dim of first input - 2, 1}
|
|
|
|
* Conv: attr kernel_shape type is ints
|
|
|
|
* Transpose: attr perm default value is {} empty int list
|
|
|
|
"""
|
|
|
|
|
2020-02-24 23:46:48 +08:00
|
|
|
|
|
|
|
def gen_op_importer(schema, file):
    """Write the importer dispatch entry for one ONNX operator to ``file``.

    The emitted C++ snippet tests ``opName`` against ``schema.name`` and, on a
    match, calls the handler registered in ``special_op_handler`` for this
    operator, falling back to ``buildOperation<mlir::ONNX<Name>Op>``. The
    handler is invoked with ``node`` as its only argument.
    """
    outer = inc_indent()
    variadic = OpSchema.FormalParameterOption.Variadic

    # Expected operand/result counts, -1 when any formal parameter is
    # variadic. Only referenced by the disabled argument code below.
    expected_num_operands = (
        -1 if any(variadic == param.option for param in schema.inputs)
        else len(schema.inputs))
    expected_num_results = (
        -1 if any(variadic == param.option for param in schema.outputs)
        else len(schema.outputs))

    default_handler = "buildOperation<mlir::ONNX{}Op>".format(schema.name)
    handler_func = special_op_handler.get(schema.name, default_handler)

    # Special handlers currently require expected num operands/results to be
    # specified.
    # TODO: remove special handlers.
    call_args = ["node"]
    # Disabled: forwarding the expected counts to the handler.
    # if (expected_num_operands != -1 or expected_num_results != -1
    #         or "buildOperation" not in handler_func):
    #     call_args.append(
    #         "/* expected_num_operands = */ {}".format(expected_num_operands))
    #     call_args.append(
    #         '/* expected_num_results = */ {}'.format(expected_num_results))

    header = outer + 'if (opName == "' + schema.name + '")\n'
    call = inc_indent(outer) + " {}({});\n".format(
        handler_func, ", ".join(call_args))
    file.write(header + call)
|
|
|
|
|
|
|
|
|
|
|
|
def build_operator_schemas():
    """Collect and select the ONNX operator schemas to generate code for.

    Returns a list of ``(domain, [(support_level, [(name, schema, versions)])])``
    entries, iterated in sorted order of domain, support level and operator
    name. Domains rejected by ``should_render_domain`` are skipped.
    ``versions`` holds every schema of the operator sorted by
    ``since_version``; ``schema`` is the one selected:

    * when ``check_operation_version`` is set: the latest version. A message
      is printed for operators absent from ``version_dict`` or whose latest
      version is newer than the one recorded there.
    * otherwise: the version recorded in ``version_dict``. Operators absent
      from ``version_dict`` are skipped. If the installed onnx package has no
      schema with the recorded version, a message is printed and the script
      exits with status 1.
    """
    # domain -> support level -> name -> [schema]
    index = defaultdict(lambda: defaultdict(lambda: defaultdict(
        list)))  # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
    for schema in defs.get_all_schemas_with_history():
        index[schema.domain][int(
            schema.support_level)][schema.name].append(schema)

    # Preprocess the Operator Schemas
    # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
    operator_schemas = list(
    )  # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
    exsting_ops = set()  # type: Set[Text]
    for domain, _supportmap in sorted(index.items()):
        if not should_render_domain(domain):
            continue
        processed_supportmap = list()
        for _support, _namemap in sorted(_supportmap.items()):
            processed_namemap = list()
            for n, unsorted_versions in sorted(_namemap.items()):
                versions = sorted(unsorted_versions,
                                  key=lambda s: s.since_version)
                schema = versions[-1]
                if schema.name in exsting_ops:
                    continue

                if check_operation_version:
                    # Generate operation of the latest version of your onnx.
                    exsting_ops.add(schema.name)
                    processed_namemap.append((n, schema, versions))

                    # Add checks against version_dict. Adjacent string
                    # literals are joined before .format() is applied, so
                    # every placeholder is filled.
                    if schema.name not in version_dict:
                        print("Check-operation-version: Operation {} "
                              "with version {} is new".format(
                                  schema.name, schema.since_version))
                    elif schema.since_version > version_dict[schema.name]:
                        print("Check-operation-version: Operation {} has a "
                              "newer version {} (old version {})".format(
                                  schema.name, schema.since_version,
                                  version_dict[schema.name]))
                else:
                    # Generate operation according to the version in version_dict.
                    if schema.name not in version_dict:
                        continue
                    wanted_version = version_dict[schema.name]
                    for candidate in reversed(versions):
                        if candidate.since_version == wanted_version:
                            exsting_ops.add(candidate.name)
                            processed_namemap.append((n, candidate, versions))
                            break
                    else:
                        # No installed schema matches the recorded version.
                        print("Your onnx may be too old. "
                              "Right version for operation {} not found".format(
                                  schema.name))
                        # Non-zero status so callers see the failure.
                        sys.exit(1)
            processed_supportmap.append((_support, processed_namemap))
        operator_schemas.append((domain, processed_supportmap))
    return operator_schemas
|
|
|
|
|
2019-12-21 14:58:23 +08:00
|
|
|
|
2019-11-19 10:08:21 +08:00
|
|
|
def main(args):  # type: (Type[Args]) -> None
    """Write the generated ONNX op definitions and importer table.

    ``args`` must expose two writable text streams: ``op_def``, which
    receives the output of ``gen_op_def`` for each selected schema, and
    ``op_importer``, which receives the output of ``gen_op_importer``. Both
    streams first get the auto-generation banner.

    When ``check_operation_version`` is set, only the banners are written.
    Instead, the ``since_version`` of every schema returned by
    ``build_operator_schemas`` is collected into a dict and pretty-printed
    to stdout.
    """
    # The banner is static text; it carries no timestamp, so nothing is
    # formatted into it.
    autogen_warning = (
        '//********************************************************\n'
        '// Do not modify this file directly.\n'
        '// This file is automatically generated via script.\n'
        '// Details can be found in docs/readonnxdefs.md .\n'
        '//********************************************************\n\n')

    op_def = args.op_def
    op_def.write(autogen_warning)

    op_importer = args.op_importer
    op_importer.write(autogen_warning)

    # Op name -> since_version; only populated in check mode.
    version_dict = dict()
    for _domain, supportmap in build_operator_schemas():
        for _, namemap in supportmap:
            for _op_type, schema, _versions in namemap:
                if check_operation_version:
                    version_dict[schema.name] = schema.since_version
                else:
                    gen_op_importer(schema, op_importer)
                    op_def.write(gen_op_def(schema))

    if check_operation_version:
        pprint.pprint(version_dict)
|
2019-11-19 10:08:21 +08:00
|
|
|
|
|
|
|
if __name__ == '__main__':
    curr_dir = os.path.dirname(os.path.realpath(__file__))

    class Args(object):
        # Output stream for the generated op definitions: an in-memory
        # buffer in --dry-run-onnx-ops mode, else ONNXOps.td.inc next to
        # this script.
        if args.dry_run_onnx_ops:
            op_def = StringIO()
        else:
            op_def_file_path = os.path.join(curr_dir, 'ONNXOps.td.inc')
            op_def = io.open(op_def_file_path, 'w', newline='')

        # Output stream for the importer table: an in-memory buffer in
        # --dry-run-op-build-table mode, else OpBuildTable.inc next to
        # this script.
        if args.dry_run_op_build_table:
            op_importer = StringIO()
        else:
            op_importer_file_path = os.path.join(curr_dir, 'OpBuildTable.inc')
            op_importer = io.open(op_importer_file_path, 'w', newline='')

    try:
        main(Args)

        # Buffers must be read before they are closed below.
        if args.dry_run_onnx_ops:
            sys.stdout.write(Args.op_def.getvalue())
        if args.dry_run_op_build_table:
            sys.stdout.write(Args.op_importer.getvalue())
    finally:
        # Flush and release the output files, including on sys.exit().
        Args.op_def.close()
        Args.op_importer.close()
|
|
|
|
|