diff --git a/.circleci/config.yml b/.circleci/config.yml
index 48fda88..3863f72 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -38,7 +38,7 @@ jobs:
      - run:
          name: Run End-To-End Tests
          command: |
-            sudo pip install -q onnx
+            sudo pip install -q -e ./ONNF/third_party/onnx
            cd ONNF/build
            cmake --build . --target run-onnx-backend-test
      - run:
diff --git a/.clang-format b/.clang-format
index a74fda4..b3276c6 100644
--- a/.clang-format
+++ b/.clang-format
@@ -1,2 +1,3 @@
BasedOnStyle: LLVM
AlwaysBreakTemplateDeclarations: Yes
+AlignAfterOpenBracket: DontAlign
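+# With DontAlign, wrapped arguments get a plain continuation indent instead of
+# being aligned to the open bracket, e.g. (illustrative):
+#   someLongFunction(argument1,
+#       argument2);
+# rather than:
+#   someLongFunction(argument1,
+#                    argument2);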
diff --git a/.gitignore b/.gitignore
index 259148f..7f8814f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,3 +30,145 @@
*.exe
*.out
*.app
+
+# Filesystem
+.DS_Store
+
+# The following .gitignore content is taken from
+# https://github.com/github/gitignore/blob/master/Python.gitignore
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
diff --git a/doc/Dialects/onnx.md b/doc/Dialects/onnx.md
index cc52df6..e4ca150 100644
--- a/doc/Dialects/onnx.md
+++ b/doc/Dialects/onnx.md
@@ -327,10 +327,10 @@ ONNX BatchNormalization operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
-1. `out_mean`: memref of any type values or tensor of any type values
-1. `out_var`: memref of any type values or tensor of any type values
-1. `saved_mean`: memref of any type values or tensor of any type values
-1. `saved_var`: memref of any type values or tensor of any type values
+1. `out_mean`: memref of any type values or tensor of any type values or none type
+1. `out_var`: memref of any type values or tensor of any type values or none type
+1. `saved_mean`: memref of any type values or tensor of any type values or none type
+1. `saved_var`: memref of any type values or tensor of any type values or none type
### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
ONNX BatchNormalization operation in test mode
@@ -375,12 +375,12 @@ ONNX BitShift operation
"Bitwise shift operator performs element-wise operation. For each input element, if the"
-" attribute "direction" is "RIGHT", this operator moves its binary representation toward"
-" the right side so that the input value is effectively decreased. If the attribute "direction""
-" is "LEFT", bits of binary representation moves toward the left side, which results the"
+" attribute \"direction\" is \"RIGHT\", this operator moves its binary representation toward"
+" the right side so that the input value is effectively decreased. If the attribute \"direction\""
+" is \"LEFT\", bits of binary representation moves toward the left side, which results the"
" increase of its actual value. The input X is the tensor to be shifted and another input"
-" Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4],"
-" and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with"
+" Y specifies the amounts of shifting. For example, if \"direction\" is \"Right\", X is [1, 4],"
+" and S is [1, 1], the corresponding output Z would be [0, 2]. If \"direction\" is \"LEFT\" with"
" X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]."
" "
" Because this operator supports Numpy-style broadcasting, X's and Y's shapes are"
@@ -413,15 +413,15 @@ ONNX Cast operation
"the converted type. The 'to' argument must be one of the data types specified"
"in the 'DataType' enum field in the TensorProto message."
""
-"Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations"
-"(e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may"
+"Casting from string tensor in plain (e.g., \"3.14\" and \"1000\") and scientific numeric representations"
+"(e.g., \"1e-5\" and \"1E8\") to float types is supported. For example, converting string \"100.5\" to an integer may"
"result 100. There are some string literals reserved for special floating-point values;"
-""+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively."
-"Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly,"
-"this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors"
-"to string tensors, plain floating-point representation (such as "314.15926") would be used. "
-"Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases "
-"of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior."
+"\"+INF\" (and \"INF\"), \"-INF\", and \"NaN\" are positive infinity, negative infinity, and not-a-number, respectively."
+"Any string which can exactly match \"+INF\" in a case-insensitive way would be mapped to positive infinite. Similarly,"
+"this case-insensitive rule is applied to \"INF\" and \"NaN\". When casting from numeric tensors"
+"to string tensors, plain floating-point representation (such as \"314.15926\") would be used. "
+"Converting non-numerical-literal string such as \"Hello World!\" is an undefined behavior. Cases "
+"of converting string representing floating-point arithmetic value, such as \"2.718\", to INT is an undefined behavior."
""
"Conversion from a numerical type to any numerical type is always allowed."
"User must be aware of precision loss and value change caused by range difference between two types."
@@ -476,8 +476,8 @@ ONNX Clip operation
#### Operands:
1. `input`: memref of any type values or tensor of any type values
-1. `min`: memref of any type values or tensor of any type values
-1. `max`: memref of any type values or tensor of any type values
+1. `min`: memref of any type values or tensor of any type values or none type
+1. `max`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -618,8 +618,8 @@ ONNX ConvInteger operation
1. `x`: memref of any type values or tensor of any type values
1. `w`: memref of any type values or tensor of any type values
-1. `x_zero_point`: memref of any type values or tensor of any type values
-1. `w_zero_point`: memref of any type values or tensor of any type values
+1. `x_zero_point`: memref of any type values or tensor of any type values or none type
+1. `w_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -678,7 +678,7 @@ ONNX Conv operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -720,7 +720,7 @@ ONNX ConvTranspose operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -884,7 +884,7 @@ ONNX DequantizeLinear operation
1. `x`: memref of any type values or tensor of any type values
1. `x_scale`: memref of any type values or tensor of any type values
-1. `x_zero_point`: memref of any type values or tensor of any type values
+1. `x_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -964,7 +964,7 @@ ONNX Dropout operation
#### Results:
1. `output`: memref of any type values or tensor of any type values
-1. `mask`: memref of any type values or tensor of any type values
+1. `mask`: memref of any type values or tensor of any type values or none type
### onnx.DynamicQuantizeLinear (ONNXDynamicQuantizeLinearOp)
ONNX DynamicQuantizeLinear operation
@@ -1297,9 +1297,9 @@ ONNX GRU operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-1. `sequence_lens`: memref of any type values or tensor of any type values
-1. `initial_h`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
+1. `sequence_lens`: memref of any type values or tensor of any type values or none type
+1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -1315,8 +1315,8 @@ ONNX GRU operation
#### Results:
-1. `Y`: memref of any type values or tensor of any type values
-1. `Y_h`: memref of any type values or tensor of any type values
+1. `Y`: memref of any type values or tensor of any type values or none type
+1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.GatherElements (ONNXGatherElementsOp)
ONNX GatherElements operation
@@ -1558,33 +1558,6 @@ ONNX Gather operation
1. `output`: memref of any type values or tensor of any type values
-### onnx.GemmNoBias (ONNXGemmNoBiasOp)
-ONNX general matrix multiply operation without bias.
-
-#### Description:
-
-
-The "onnx.Gemm" generic matrix multiplication without bias.
-
-
-#### Operands:
-
-1. `A`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-
-#### Attributes:
-
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `alpha` | `FloatAttr` | 32-bit float attribute attribute |
-| `beta` | `FloatAttr` | 32-bit float attribute attribute |
-| `transA` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `transB` | `IntegerAttr` | 64-bit integer attribute attribute |
-
-#### Results:
-
-1. `o_Y`: memref of any type values or tensor of any type values
-
### onnx.Gemm (ONNXGemmOp)
ONNX Gemm operation
@@ -1609,7 +1582,7 @@ ONNX Gemm operation
1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
-1. `C`: memref of any type values or tensor of any type values
+1. `C`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -2013,11 +1986,11 @@ ONNX LSTM operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-1. `sequence_lens`: memref of any type values or tensor of any type values
-1. `initial_h`: memref of any type values or tensor of any type values
-1. `initial_c`: memref of any type values or tensor of any type values
-1. `P`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
+1. `sequence_lens`: memref of any type values or tensor of any type values or none type
+1. `initial_h`: memref of any type values or tensor of any type values or none type
+1. `initial_c`: memref of any type values or tensor of any type values or none type
+1. `P`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -2033,9 +2006,9 @@ ONNX LSTM operation
#### Results:
-1. `Y`: memref of any type values or tensor of any type values
-1. `Y_h`: memref of any type values or tensor of any type values
-1. `Y_c`: memref of any type values or tensor of any type values
+1. `Y`: memref of any type values or tensor of any type values or none type
+1. `Y_h`: memref of any type values or tensor of any type values or none type
+1. `Y_c`: memref of any type values or tensor of any type values or none type
### onnx.LeakyRelu (ONNXLeakyReluOp)
ONNX LeakyRelu operation
@@ -2160,24 +2133,24 @@ ONNX Loop operation
""
" Operator inputs defined as (max_trip_count, condition_var)."
""
-" input ("", ""):"
+" input (\"\", \"\"):"
" for (int i=0; ; ++i) {"
" cond = ... // Note this value is ignored, but is required in the body"
" }"
""
-" input ("", cond) // Note this is analogous to a while loop"
+" input (\"\", cond) // Note this is analogous to a while loop"
" bool cond = ...;"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
-" input ("", 1) // Note this is analogous to a do-while loop"
+" input (\"\", 1) // Note this is analogous to a do-while loop"
" bool cond = true"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
-" input (trip_count, "") // Note this is analogous to a for loop"
+" input (trip_count, \"\") // Note this is analogous to a for loop"
" int trip_count = ..."
" for (int i=0; i < trip_count; ++i) {"
" cond = ...; // ignored"
@@ -2203,15 +2176,15 @@ ONNX Loop operation
" }"
""
" graph body-net ("
-" %i[INT32, scalar] // iteration number"
-" %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used"
-" %b_in[INT32, scalar] // incoming value of loop-carried-dependency b"
+" %i[INT32, scalar]"
+" %keepgoing[BOOL, scalar]"
+" %b[INT32, scalar]"
" ) {"
-" %my_local = Add(%a, %b_in)"
-" %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b"
-" %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition"
-" %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated"
-" return %keepgoing_out, %b_out, %user_defined_val"
+" %my_local = Add(%a, %b)"
+" %b_out = Sub(%a, %b)"
+" %keepgoing_out = Greater(%my_local, %b_out)"
+" %user_defined_vals = Add(%b, %b)"
+" return %keepgoing_out, %b_out, %user_defined_vals"
" }"
""
"*Sample equivalent C code*"
@@ -2226,51 +2199,31 @@ ONNX Loop operation
" const int max_trip_count = 10; // Analogous to input M"
" int user_defined_vals[]; // Imagine this is resizable"
" /* End implicitly-defined code */"
-" /* initialize loop-carried variables and scan-output variables */"
-" bool keepgoing_out = keepgoing"
-" int b_out = b"
-""
-" for (int i=0; i < max_trip_count && keepgoing_out; ++i) {"
-" /* Implicitly-defined code: bind actual parameter values"
-" to formal parameter variables of loop-body */"
-" bool keepgoing_in = keepgoing_out; "
-" bool b_in = b_out;"
-""
+" for (int i=0; i < max_trip_count && keepgoing; ++i) {"
" /* User-defined code (loop body) */"
-" int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine"
-" b_out = a - b_in;"
-" keepgoing_out = my_local > b_out; "
-" user_defined_val = b_in + b_in; // b_in and b_out are different variables"
+" int my_local = a + b; // Reading values in the enclosing scope is fine"
+" b = a - b; // writes fine if we specify b as a loop-carried dependency"
+" keepgoing = my_local > b; // keepgoing is a loop-carried dependency"
+" user_defined_vals[i] = b + b;"
" /* End user-defined code */"
-""
-" /* Implicitly defined-code */"
-" user_defined_vals[i] = user_defined_val // accumulate scan-output values"
" }"
-" // int t = my_local; // Can't do this. my_local is not accessible here."
+" // my_local = 123; // Can't do this. my_local was defined in the the body"
""
-" // The values below are bound to the output variables of the loop and therefore accessible"
-" // b_out; user_defined_vals; keepgoing_out;"
+" // These below values are live-out from the loop and therefore accessible"
+" b_out; user_defined_vals; keepgoing_out;"
" }"
""
"There are several things of note in this code snippet:"
""
-"1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can"
+"1) Values from the enclosing scope (i.e. variable a here) are in scope and can"
" be referenced in the inputs of the loop."
-"2) Any values computed in the loop body that needs to be used in a subsequent"
-" iteration or after the loop are modelled using a pair of variables in the loop-body,"
-" consisting of an input variable (eg., b_in) and an output variable (eg., b_out)."
-" These are referred to as loop-carried dependences. The loop operation node"
-" supplies the input value of the input variable for the first iteration, and"
-" returns the output value of the output variable produced by the final"
-" iteration."
-"3) Scan_output variables are used to implicitly concatenate values computed across"
-" all the iterations. In the above example, the value of user_defined_val computed"
-" over all iterations are concatenated and returned as the value of user_defined_vals"
-" after the loop."
-"4) Values created in the body cannot be accessed in the enclosing scope,"
-" except using the mechanism described above."
+"2) Any variables which you wish to make available in the enclosing scope (i.e."
+" the variables b and keepgoing) must be declared as either loop-carried"
+" dependencies (both at the op inputs and output and at the body net input and"
+" output) or scan_outputs."
+"3) Values created in the body cannot be accessed in the enclosing scope."
""
-"Note that the semantics of this op support "diagonal" or "wavefront" execution."
+"Note that the semantics of this op support \"diagonal\" or \"wavefront\" execution."
"(See Step 3 here for an example:"
"https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)."
"Frontends should emit multi-layer RNNs as a series of While operators (with"
@@ -2280,8 +2233,8 @@ ONNX Loop operation
#### Operands:
-1. `M`: memref of any type values or tensor of any type values
-1. `cond`: memref of any type values or tensor of any type values
+1. `M`: memref of any type values or tensor of any type values or none type
+1. `cond`: memref of any type values or tensor of any type values or none type
1. `v_initial`: memref of any type values or tensor of any type values
#### Attributes:
@@ -2360,8 +2313,8 @@ ONNX MatMulInteger operation
1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
-1. `a_zero_point`: memref of any type values or tensor of any type values
-1. `b_zero_point`: memref of any type values or tensor of any type values
+1. `a_zero_point`: memref of any type values or tensor of any type values or none type
+1. `b_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -2444,7 +2397,7 @@ ONNX MaxPool operation
" ```"
" pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]"
" ```"
-" The output of each pooling window is maximum number of elements exclude pad. "
+" The output of each pooling window is maximum number of elements exclude pad."
" "
#### Operands:
@@ -2466,7 +2419,7 @@ ONNX MaxPool operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
-1. `Indices`: memref of any type values or tensor of any type values
+1. `Indices`: memref of any type values or tensor of any type values or none type
### onnx.MaxPoolSingleOut (ONNXMaxPoolSingleOutOp)
ONNX MaxPool operation with a single output.
@@ -2552,7 +2505,7 @@ ONNX MaxUnpool operation
1. `X`: memref of any type values or tensor of any type values
1. `I`: memref of any type values or tensor of any type values
-1. `output_shape`: memref of any type values or tensor of any type values
+1. `output_shape`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -2752,9 +2705,9 @@ ONNX NonMaxSuppression operation
1. `boxes`: memref of any type values or tensor of any type values
1. `scores`: memref of any type values or tensor of any type values
-1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values
-1. `iou_threshold`: memref of any type values or tensor of any type values
-1. `score_threshold`: memref of any type values or tensor of any type values
+1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values or none type
+1. `iou_threshold`: memref of any type values or tensor of any type values or none type
+1. `score_threshold`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -3067,7 +3020,7 @@ ONNX Pad operation
1. `data`: memref of any type values or tensor of any type values
1. `pads`: memref of any type values or tensor of any type values
-1. `constant_value`: memref of any type values or tensor of any type values
+1. `constant_value`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -3124,7 +3077,7 @@ ONNX QLinearConv operation
1. `w_zero_point`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -3188,7 +3141,7 @@ ONNX QuantizeLinear operation
1. `x`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values
-1. `y_zero_point`: memref of any type values or tensor of any type values
+1. `y_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -3270,9 +3223,9 @@ ONNX RNN operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-1. `sequence_lens`: memref of any type values or tensor of any type values
-1. `initial_h`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
+1. `sequence_lens`: memref of any type values or tensor of any type values or none type
+1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -3287,8 +3240,8 @@ ONNX RNN operation
#### Results:
-1. `Y`: memref of any type values or tensor of any type values
-1. `Y_h`: memref of any type values or tensor of any type values
+1. `Y`: memref of any type values or tensor of any type values or none type
+1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.RandomNormalLike (ONNXRandomNormalLikeOp)
ONNX RandomNormalLike operation
@@ -3813,14 +3766,14 @@ ONNX Resize operation
"Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor."
"Each dimension value of the output tensor is:"
-" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified."
+" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified."
#### Operands:
1. `X`: memref of any type values or tensor of any type values
1. `roi`: memref of any type values or tensor of any type values
1. `scales`: memref of any type values or tensor of any type values
-1. `sizes`: memref of any type values or tensor of any type values
+1. `sizes`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -4438,7 +4391,7 @@ ONNX SequenceErase operation
#### Operands:
1. `input_sequence`: memref of any type values or tensor of any type values
-1. `position`: memref of any type values or tensor of any type values
+1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -4463,7 +4416,7 @@ ONNX SequenceInsert operation
1. `input_sequence`: memref of any type values or tensor of any type values
1. `tensor`: memref of any type values or tensor of any type values
-1. `position`: memref of any type values or tensor of any type values
+1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -4680,8 +4633,8 @@ ONNX Slice operation
1. `data`: memref of any type values or tensor of any type values
1. `starts`: memref of any type values or tensor of any type values
1. `ends`: memref of any type values or tensor of any type values
-1. `axes`: memref of any type values or tensor of any type values
-1. `steps`: memref of any type values or tensor of any type values
+1. `axes`: memref of any type values or tensor of any type values or none type
+1. `steps`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -4834,7 +4787,7 @@ ONNX SplitToSequence operation
#### Operands:
1. `input`: memref of any type values or tensor of any type values
-1. `split`: memref of any type values or tensor of any type values
+1. `split`: memref of any type values or tensor of any type values or none type
#### Attributes:
@@ -4902,9 +4855,9 @@ ONNX StringNormalizer operation
"StringNormalization performs string operations for basic cleaning."
"This operator has only one input (denoted by X) and only one output"
"(denoted by Y). This operator first examines the elements in the X,"
-"and removes elements specified in "stopwords" attribute."
+"and removes elements specified in \"stopwords\" attribute."
"After removing stop words, the intermediate result can be further lowercased,"
-"uppercased, or just returned depending the "case_change_action" attribute."
+"uppercased, or just returned depending the \"case_change_action\" attribute."
"This operator only accepts [C]- and [1, C]-tensor."
"If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]"
"if input shape is [C] and shape [1, 1] if input shape is [1, C]."
@@ -5034,8 +4987,8 @@ ONNX TfIdfVectorizer operation
"respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output."
"Note that we may consider all skips up to S when generating the n-grams."
""
-"The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and"
-"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF","
+"The examples used above are true if mode is \"TF\". If mode is \"IDF\", all the counts larger than 1 would be truncated to 1 and"
+"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is \"TFIDF\","
"this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute."
""
"Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor."
@@ -5123,9 +5076,9 @@ ONNX TopK operation
" contains the indices of the top k elements (original indices from the input"
" tensor)."
""
-"If "largest" is 1 (the default value) then the k largest elements are returned."
-"If "sorted" is 1 (the default value) then the resulting k elements will be sorted."
-"If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined."
+"If \"largest\" is 1 (the default value) then the k largest elements are returned."
+"If \"sorted\" is 1 (the default value) then the resulting k elements will be sorted."
+"If \"sorted\" is 0, order of returned 'Values' and 'Indices' are undefined."
""
"Given two equivalent values, this operator uses the indices along the axis as"
" a tiebreaker. That is, the element with the lower index will appear first."
@@ -5184,7 +5137,7 @@ ONNX Unique operation
"This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. "
"The first output tensor 'Y' contains all unique values or subtensors of the input. "
"The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. "
-"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". "
+"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. \". "
"The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. "
""
"Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. "
@@ -5268,9 +5221,9 @@ ONNX Unique operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
-1. `indices`: memref of any type values or tensor of any type values
-1. `inverse_indices`: memref of any type values or tensor of any type values
-1. `counts`: memref of any type values or tensor of any type values
+1. `indices`: memref of any type values or tensor of any type values or none type
+1. `inverse_indices`: memref of any type values or tensor of any type values or none type
+1. `counts`: memref of any type values or tensor of any type values or none type
### onnx.Unsqueeze (ONNXUnsqueezeOp)
ONNX Unsqueeze operation
diff --git a/doc/gen_doc.py b/doc/gen_doc.py
index d42eb27..1c593a5 100644
--- a/doc/gen_doc.py
+++ b/doc/gen_doc.py
@@ -4,10 +4,11 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
import io
import os
import sys
+import datetime
import numpy as np # type: ignore
@@ -17,59 +18,53 @@ from onnx.backend.test.case import collect_snippets
from onnx.backend.sample.ops import collect_sample_implementations
from typing import Any, Text, Sequence, Dict, List, Type, Set, Tuple
-
-#controls on ONNF code gen
-#specify attr default value
+# Manual specification of attribute defaults.
special_attr_defaults = dict([
-# ("AveragePool "+"kernel_shape", ('ints', '{}')),
-# ("MaxPool "+"kernel_shape", ('ints', '{}')),
-# ("Cast "+"to", ('int', '0')),
-# ("Concat "+"axis", ('int', '0')),
-# ("Conv "+"group", ('int', '1')),
-# ("Unsqueeze "+"axes", ('ints', '{}')),
-# ("RNN "+"activation_alpha", ('floats', '{}')),
-# ("RNN "+"activation_beta", ('floats', '{}')),
- ])
+ # ("AveragePool.kernel_shape", ('ints', '{}')),
+ # ("MaxPool.kernel_shape", ('ints', '{}')),
+ # ("Cast.to", ('int', '0')),
+ # ("Concat.axis", ('int', '0')),
+ # ("Conv.group", ('int', '1')),
+ # ("Unsqueeze.axes", ('ints', '{}')),
+ # ("RNN.activation_alpha", ('floats', '{}')),
+ # ("RNN.activation_beta", ('floats', '{}')),
+])
-#specify the function name in src/builder/frontend_dialect_transformer.cpp
-#the reason for Conv and MaPool is to handled optional arguments
+# Special operation importing handlers.
special_op_handler = dict([
- ("Conv", "ImportNodeConv"),
- ("MaxPool", "ImportNodeMaxPool"),
- ("BatchNormalization", "ImportNodeBatchNormalization"),
- ("Gemm", "ImportNodeGemm"),
- ("Pad", "ImportNodePad"),
- #("Transpose", "ImportNodeTranspose")
- ])
+ ("Conv", "ImportNodeConv"),
+ ("MaxPool", "ImportNodeMaxPool"),
+ ("BatchNormalization", "ImportNodeBatchNormalization"),
+ ("Pad", "ImportNodePad"),
+ #("Transpose", "ImportNodeTranspose")
+])
-#add an Op in this list if ShapeInterference is defined for this Op
-ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
- 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
- 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
- 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
- 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
- 'ReduceMax', 'ReduceMin', 'ReduceProd', 'ReduceSum',
- 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'Sign']
+# Operations supporting shape inference.
+OpsWithShapeInference = [
+ 'Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu', 'Add', 'Mul', 'Div',
+ 'Sub', 'And', 'Or', 'Xor', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm',
+ 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
+ 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
+ 'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
+ 'Sign'
+]
-CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
- 'ReduceLogSumExp', 'ReduceSumSquare']
+# Operations supporting canonicalization.
+OpsWithCanonicalizer = [
+ 'Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
+ 'ReduceLogSumExp', 'ReduceSumSquare', 'Gemm'
+]
-#add an Op in this list if the Op needs result type deduction which is required
-#when writing declarative rewriting rules. Deduced type is always
-#an UnrankedTensorType whose element type is the same as the first operand's
-#element type.
-#currenlty, there are only two build methods generated:
-# - one with operands and attributes having a separate parameter, and
-# - one with operands and attributes having aggregated parameters.
+# Add an Op in this list if the Op needs result type deduction which is required
+# when writing declarative rewriting rules. Deduced type is always
+# an UnrankedTensorType whose element type is the same as the first operand's
+# element type.
+#
+# Currently, there are only two build methods generated:
+# - one with operands and attributes having a separate parameter, and
+# - one with operands and attributes having aggregated parameters.
custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare']
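+# As an illustrative sketch (not verbatim generator output), an op in this
+# list gets two builders along the lines of:
+#   OpBuilder<"Builder *builder, OperationState &state, Value X, FloatAttr alpha", [{...}]>
+#   OpBuilder<"Builder *builder, OperationState &state, ValueRange operands,
+#              ArrayRef<NamedAttribute> attributes", [{...}]>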
-manual_code_in_op_def = dict([
- ('DummyExample', ' let extraClassDeclaration = [{ \n'+
- ' static StringRef getPermAttrName() { return "perm"; }\n'+
- ' }];\n')
- ])
-
-
SNIPPETS = collect_snippets()
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
ONNX_ML = not bool(os.getenv('ONNX_ML') == '0')
@@ -77,19 +72,12 @@ ONNX_ML = not bool(os.getenv('ONNX_ML') == '0')
ONNX_ML = False
print("ONNX_ML", ONNX_ML)
-
if ONNX_ML:
ext = '-ml.md'
else:
ext = '.md'
-def display_number(v): # type: (int) -> Text
- if defs.OpSchema.is_infinite(v):
- return '∞'
- return Text(v)
-
-
def should_render_domain(domain): # type: (Text) -> bool
if domain == ONNX_ML_DOMAIN and not ONNX_ML:
return False
@@ -98,13 +86,6 @@ def should_render_domain(domain): # type: (Text) -> bool
return True
-def format_name_with_domain(domain, schema_name): # type: (Text, Text) -> Text
- if domain:
- return '{}.{}'.format(domain, schema_name)
- else:
- return schema_name
-
-
def display_attr_type(v): # type: (OpSchema.AttrType) -> Text
assert isinstance(v, OpSchema.AttrType)
s = Text(v)
@@ -114,354 +95,315 @@ def display_attr_type(v): # type: (OpSchema.AttrType) -> Text
return s
-def display_domain(domain): # type: (Text) -> Text
- if domain:
- return "the '{}' operator set".format(domain)
- else:
- return "the default ONNX operator set"
-
-
-def display_domain_short(domain): # type: (Text) -> Text
- if domain:
- return domain
- else:
- return 'ai.onnx (default)'
-
-
-def display_version_link(name, version): # type: (Text, int) -> Text
- changelog_md = 'Changelog' + ext
- name_with_ver = '{}-{}'.format(name, version)
-    return '<a href="{}#{}">{}</a>'.format(changelog_md, name_with_ver, name_with_ver)
-
def get_unique_output_name(schema, name):
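+    # e.g. BatchNormalization has both an input and an output named "mean";
+    # the output is renamed to "out_mean" to avoid the clash.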
- for input in schema.inputs :
- if input.name == name :
- return 'out_'+name
+ for input in schema.inputs:
+ if input.name == name:
+ return 'out_' + name
return name
-def display_schema(schema, versions): # type: (OpSchema, Sequence[OpSchema]) -> Text
- s = ''
- # doc
- if schema.doc:
- s += '\n'
- s += '\n'.join(' ' + line
- for line in schema.doc.lstrip().splitlines())
- s += '\n'
+def onnx_attr_type_to_mlir_attr_type(t):
+ onnx_attr_type = Text(t)
+ onnx_attr_type = onnx_attr_type[onnx_attr_type.rfind('.') + 1:].lower()
- # since version
- s += '\n#### Version\n'
- if schema.support_level == OpSchema.SupportType.EXPERIMENTAL:
- s += '\nNo versioning maintained for experimental ops.'
+ if onnx_attr_type == 'int':
+ mlir_attr_type = 'I64Attr'
+ elif onnx_attr_type == 'float':
+ mlir_attr_type = 'F32Attr'
+ elif onnx_attr_type == 'ints':
+ mlir_attr_type = 'I64ArrayAttr'
+ elif onnx_attr_type == 'floats':
+ mlir_attr_type = 'F32ArrayAttr'
+ elif onnx_attr_type == "string":
+ mlir_attr_type = 'StrAttr'
+ elif onnx_attr_type == "strings":
+ mlir_attr_type = 'StrArrayAttr'
else:
- s += '\nThis version of the operator has been ' + ('deprecated' if schema.deprecated else 'available') + ' since version {}'.format(schema.since_version)
- s += ' of {}.\n'.format(display_domain(schema.domain))
- if len(versions) > 1:
- # TODO: link to the Changelog.md
- s += '\nOther versions of this operator: {}\n'.format(
- ', '.join(display_version_link(format_name_with_domain(v.domain, v.name),
- v.since_version) for v in versions[:-1]))
-
- # If this schema is deprecated, don't display any of the following sections
- if schema.deprecated:
- return s
-
- # attributes
- if schema.attributes:
- s += '\n#### Attributes\n\n'
-        s += '<dl>\n'
- for _, attr in sorted(schema.attributes.items()):
- # option holds either required or default value
- opt = ''
- if attr.required:
- opt = 'required'
- elif attr.default_value.name:
- default_value = helper.get_attribute_value(attr.default_value)
-
- def format_value(value): # type: (Any) -> Text
- if isinstance(value, float):
- formatted = str(np.round(value, 5))
- # use default formatting, unless too long.
- if (len(formatted) > 10):
- formatted = str("({:e})".format(value))
- return formatted
- elif isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3:
- return str(value.decode('utf-8'))
- return str(value)
-
- if isinstance(default_value, list):
- default_value = [format_value(val) for val in default_value]
- else:
- default_value = format_value(default_value)
- opt = 'default is {}'.format(default_value)
-
-            s += '<dt><tt>{}</tt> : {}{}</dt>\n'.format(
-                attr.name,
-                display_attr_type(attr.type),
-                ' ({})'.format(opt) if opt else '')
-            s += '<dd>{}</dd>\n'.format(attr.description)
-        s += '</dl>\n'
-
- # inputs
- s += '\n#### Inputs'
- if schema.min_input != schema.max_input:
- s += ' ({} - {})'.format(display_number(schema.min_input),
- display_number(schema.max_input))
- s += '\n\n'
- if schema.inputs:
-        s += '<dl>\n'
- for input in schema.inputs:
- option_str = ""
- if OpSchema.FormalParameterOption.Optional == input.option:
- option_str = " (optional)"
- elif OpSchema.FormalParameterOption.Variadic == input.option:
- if input.isHomogeneous:
- option_str = " (variadic)"
- else:
- option_str = " (variadic, heterogeneous)"
-            s += '<dt><tt>{}</tt>{} : {}</dt>\n'.format(input.name, option_str, input.typeStr)
-            s += '<dd>{}</dd>\n'.format(input.description)
-        s += '</dl>\n'
-
- # outputs
- s += '\n#### Outputs'
- if schema.min_output != schema.max_output:
- s += ' ({} - {})'.format(display_number(schema.min_output),
- display_number(schema.max_output))
- s += '\n\n'
-
- if schema.outputs:
-        s += '<dl>\n'
- for output in schema.outputs:
- option_str = ""
- if OpSchema.FormalParameterOption.Optional == output.option:
- option_str = " (optional)"
- elif OpSchema.FormalParameterOption.Variadic == output.option:
- if output.isHomogeneous:
- option_str = " (variadic)"
- else:
- option_str = " (variadic, heterogeneous)"
-            s += '<dt><tt>{}</tt>{} : {}</dt>\n'.format(get_unique_output_name(schema, output.name), option_str, output.typeStr)
-            s += '<dd>{}</dd>\n'.format(output.description)
-        s += '</dl>\n'
-
- # type constraints
- s += '\n#### Type Constraints'
- s += '\n\n'
- if schema.type_constraints:
-        s += '<dl>\n'
- for type_constraint in schema.type_constraints:
- allowedTypes = type_constraint.allowed_type_strs
- if (len(allowedTypes) > 0):
- allowedTypeStr = allowedTypes[0]
- for allowedType in allowedTypes[1:]:
- allowedTypeStr += ', ' + allowedType
-            s += '<dt><tt>{}</tt> : {}</dt>\n'.format(
-                type_constraint.type_param_str, allowedTypeStr)
-            s += '<dd>{}</dd>\n'.format(type_constraint.description)
-        s += '</dl>\n'
-
- # Function Body
- if schema.has_function: # type: ignore
- s += '\n#### Function\n'
- s += '\nThe Function can be represented as a function.\n'
-
- return s
+ mlir_attr_type = 'AnyAttr'
+ #TODO: tensor and sparse tensor
+ return mlir_attr_type
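+# e.g. onnx_attr_type_to_mlir_attr_type(OpSchema.AttrType.INTS) yields
+# 'I64ArrayAttr'; unrecognized attribute kinds fall back to 'AnyAttr'.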
-def support_level_str(level): # type: (OpSchema.SupportType) -> Text
- return \
- "experimental " if level == OpSchema.SupportType.EXPERIMENTAL else ""
+# TODO: is there any better way to do this?
+def tblgen_attr_type_to_cpp_type(t):
+ if 'I64Attr' in t:
+ cpp_type = 'IntegerAttr'
+ elif 'F32Attr' in t:
+ cpp_type = 'FloatAttr'
+ elif 'I64ArrayAttr' in t or 'F32ArrayAttr' in t:
+ cpp_type = 'ArrayAttr'
+ elif 'StrAttr' in t:
+ cpp_type = 'StringAttr'
+ elif 'strings' in t:
+ cpp_type = 'ArrayAttr'
+ else:
+ cpp_type = 'Attribute'
+ return cpp_type
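+# e.g. tblgen_attr_type_to_cpp_type('DefaultValuedAttr<I64Attr, "1">') yields
+# 'IntegerAttr', since the TableGen type string contains 'I64Attr'.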
-def convert_type(tstr) :
- tfrom = np.array(['bool', 'int8', 'int16', 'int32', 'int64',
- 'unkown', 'float16', 'float', 'double'])
- tto =np.array(['I1', 'I8', 'I16', 'I32', 'I64',
- 'BF16', 'F16', 'F32', 'F64'])
+
+def tblgen_operand_type_to_cpp_type(op_type):
+ if op_type.startswith('Variadic'):
+ mytype = 'ValueRange'
+ else:
+ mytype = 'Value'
+ return mytype
+
+
+def np_type_to_tblgen_attr_type(tstr):
+ tfrom = np.array([
+ 'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
+ 'float', 'double'
+ ])
+ tto = np.array(
+ ['I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64'])
index = -1
- for i in range(len(tfrom)) :
- if tfrom[i] in tstr :
+ for i in range(len(tfrom)):
+ if tfrom[i] in tstr:
index = i
break
- if index == -1 :
+ if index == -1:
print("error", tstr)
return ''
- else :
+ else:
return tto[i]
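+# e.g. np_type_to_tblgen_attr_type('tensor(int64)') yields 'I64'; unknown
+# element types print an error and yield ''.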
-def collect_types(schema, input) :
- allowedTypeStr=''
- #first step just ignore the type constraints
- return allowedTypeStr
- if input.typeStr :
- tstr = input.typeStr
- else :
- return allwedTypeStr
- if schema.type_constraints:
- for type_constraint in schema.type_constraints:
- if type_constraint.type_param_str != tstr :
- continue
- allowedTypes = type_constraint.allowed_type_strs
- allowedTypeStr=''
- if (len(allowedTypes) > 0):
- t = convert_type(allowedTypes[0])
- if t == '' :
- return ''
- allowedTypeStr += t
- for allowedType in allowedTypes[1:]:
- t = convert_type(allowedType)
- if t == '' :
- return ''
- if not t in allowedTypeStr :
- allowedTypeStr += ', '+t
- return allowedTypeStr
+def get_allowed_elem_types(schema, input):
+ allowed_types_str = None
+ return allowed_types_str
+ # TODO: enable type constraints.
+ # if input.typeStr :
+ # tstr = input.typeStr
+ # else :
+ # return allwedTypeStr
+ # if schema.type_constraints:
+ # for type_constraint in schema.type_constraints:
+ # if type_constraint.type_param_str != tstr :
+ # continue
+ # allowedTypes = type_constraint.allowed_type_strs
+ # allowedTypeStr=''
+ # if (len(allowedTypes) > 0):
+ # t = convert_type(allowedTypes[0])
+ # if t == '' :
+ # return ''
+ # allowedTypeStr += t
+ # for allowedType in allowedTypes[1:]:
+ # t = convert_type(allowedType)
+ # if t == '' :
+ # return ''
+ # if not t in allowedTypeStr :
+ # allowedTypeStr += ', '+t
+ #
+ # return allowedTypeStr
+ #
+ # return allowedTypeStr
- return allowedTypeStr
-def gen_schema(schema) :
- line_indent = ' '
+def inc_indent(indent=None):
+ return "" if indent is None else indent + ' ' * 2
- #s = 'def ONNX'+schema.name+str(schema.since_version)+'Op:ONNX_Op<"'+schema.name+'", \n'
- s = 'def ONNX'+schema.name+'Op:ONNX_Op<"'+schema.name+'", \n'
- s += line_indent+' [NoSideEffect'
- if schema.name in ShapeInferenceList :
-        s+= ', DeclareOpInterfaceMethods<ShapeInferenceOpInterface>'
- s += ']> {'
- if schema.name in CanonicalList:
- s += '\n'+line_indent+'let hasCanonicalizer = 1;'
+def dec_indent(indent):
+ return indent[:-2]
- #summary
- s += '\n'+line_indent
- s += 'let summary = "ONNX '+schema.name+' operation";'
- #description
- s += '\n'+line_indent
- s += 'let description = [{'
- if schema.doc:
- """
- s += '\n'.join(line_indent + line
- for line in schema.doc.lstrip().splitlines())
- """
- for line in schema.doc.lstrip().splitlines():
- line = line.replace('}]', '\}\]')
- s += '\n'+line_indent+' '+'"'+line+'"'
- else :
- s += '\n'+line_indent*2 +'no doc for this op from onnx'
- s += '\n'+line_indent+'}];'
+def join_args(args):
+ return ", ".join(args)
- #input
- s+= '\n'+line_indent+'let arguments = (ins '
- isfirst = True
- # add operands
- operand_ins = get_operand_ins(schema)
- for operand_type, operand_name in operand_ins:
- if not isfirst:
- s+= ',\n '
+
+def get_operands_or_results(schema, is_input):
+ value_list = schema.inputs if is_input else schema.outputs
+ if not value_list:
+ return OrderedDict()
+
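+    # Render a list of candidate types as a single TableGen constraint, e.g.
+    # ["AnyMemRef", "AnyTensor"] becomes "AnyTypeOf<[AnyMemRef, AnyTensor]>".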
+ def any_type_of(types):
+ assert isinstance(types, list)
+ if len(types) == 1:
+ return types[0]
else:
- isfirst = False
- s+=operand_type+':$'+operand_name
+ return "AnyTypeOf<[{}]>".format(", ".join(types))
- # add attributes
- attr_ins = get_attr_ins(schema)
- for attr_type, attr_name in attr_ins:
- if not isfirst:
- s += ',\n '
- else :
- isfirst = False
- s += attr_type+':$'+attr_name
- s+= ');'
+ name_to_types = OrderedDict()
+ for value in value_list:
+ elem_types = get_allowed_elem_types(schema, value)
- #output
- s+= '\n'+line_indent+'let results = (outs '
- if schema.outputs:
- for output in schema.outputs:
- if output != schema.outputs[0] :
- s+= ',\n '
- #need to interpret output.typeStr
- etypes=collect_types(schema, output)
- if etypes == '':
- s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
+ if elem_types is None:
+ types = ["AnyMemRef", "AnyTensor"]
+ else:
+ types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
+ types = list(map(lambda x: x.format(elem_types), types))
+
+ if OpSchema.FormalParameterOption.Optional == value.option:
+ types.append("NoneType")
+ elif OpSchema.FormalParameterOption.Variadic == value.option:
+ if value.isHomogeneous:
+ types = ["Variadic<{}>".format(any_type_of(types))]
else:
- s+= 'TensorOf<['+etypes+']>'
- s += ':$'+get_unique_output_name(schema, output.name)
- s+= ');\n'
+                # TODO: handle (variadic, heterogeneous).
+                print("warning: (variadic, heterogeneous) for " + schema.name +
+                      ' ' + value.name)
- #s+= 'let hasCanonicalizer = 1;'
-
- #TODO: any better way to do this.
- def get_attr_type_for_builder(attr_type) :
- if 'I64Attr' in attr_type :
- mytype = 'IntegerAttr'
- elif 'F32Attr' in attr_type :
- mytype = 'FloatAttr'
- elif 'I64ArrayAttr' in attr_type or 'F32ArrayAttr' in attr_type:
- mytype = 'ArrayAttr'
- elif 'StrAttr' in attr_type :
- mytype = 'StringAttr'
- elif 'strings' in attr_type :
- mytype = 'ArrayAttr'
- else :
- mytype ='Attribute'
- return mytype
-
- def get_op_type_for_builder(op_type):
- if op_type.startswith('Variadic'):
- mytype = 'ValueRange'
+        # Since an output name can coincide with that of an input, we
+        # explicitly prepend the prefix "out_" to such names to disambiguate.
+ if is_input:
+ value_name = value.name
else:
- mytype = 'Value'
- return mytype
+ value_name = get_unique_output_name(schema, value.name)
+
+ name_to_types[value_name] = any_type_of(types)
+ return name_to_types
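+# e.g. for Conv, whose bias B is optional, this returns an OrderedDict along
+# the lines of (illustrative):
+#   {'X': 'AnyTypeOf<[AnyMemRef, AnyTensor]>',
+#    'W': 'AnyTypeOf<[AnyMemRef, AnyTensor]>',
+#    'B': 'AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>'}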
+
+
+def get_attrs(schema):
+ def get_attr_type_optional(attr_type):
+ return 'OptionalAttr<{}>'.format(
+ onnx_attr_type_to_mlir_attr_type(attr_type))
+
+ def get_attr_type_with_default(attr_type, attr_default):
+ return 'DefaultValuedAttr<{}, "{}">'.format(
+ onnx_attr_type_to_mlir_attr_type(attr_type), attr_default)
+
+ if not schema.attributes:
+ return OrderedDict()
+
+ name_to_type = OrderedDict()
+ for _, attr in sorted(schema.attributes.items()):
+ qualified_attr_name = "{}.{}".format(schema.name, attr.name)
+ if qualified_attr_name in special_attr_defaults:
+ name_to_type[attr.name] = get_attr_type_with_default(
+ *special_attr_defaults[qualified_attr_name])
+
+ # option holds either required or default value
+ elif attr.required:
+ name_to_type[attr.name] = onnx_attr_type_to_mlir_attr_type(
+ attr.type)
+ elif attr.default_value.name:
+
+ def format_value(value): # type: (Any) -> Text
+ if isinstance(value, float):
+ formatted = str(np.round(value, 5))
+ # use default formatting, unless too long.
+ if (len(formatted) > 10):
+ formatted = str("({:e})".format(value))
+ return formatted
+ elif isinstance(
+ value,
+ (bytes, bytearray)) and sys.version_info[0] == 3:
+ return str(value.decode('utf-8'))
+ return str(value)
+
+ default_value = helper.get_attribute_value(attr.default_value)
+ if isinstance(default_value, list):
+ default_value = [format_value(val) for val in default_value]
+ default_value_str = '{}'.format(default_value)
+ default_value_str = default_value_str.replace('[', '{', 1)
+ default_value_str = default_value_str.replace(']', '}', 1)
+ if Text(attr.type) == "AttrType.STRINGS":
+ default_value_str = default_value_str.replace("'", '\\"')
+ else:
+ default_value_str = default_value_str.replace("'", '')
+ else:
+ default_value = format_value(default_value)
+ default_value_str = default_value
+
+ name_to_type[attr.name] = get_attr_type_with_default(
+ attr.type, default_value_str)
+ else:
+ name_to_type[attr.name] = get_attr_type_optional(attr.type)
+ return name_to_type
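+# e.g. a required int attribute maps to 'I64Attr', an optional one to
+# 'OptionalAttr<I64Attr>', and one with a default value of 1 to
+# 'DefaultValuedAttr<I64Attr, "1">'.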
+
+
+def gen_op_def(schema):
+ indent = inc_indent()
+ s = 'def ONNX{0}Op:ONNX_Op<"{0}",\n'.format(schema.name)
+
+ # Generate decl for op traits.
+ traits = ["NoSideEffect"]
+ if schema.name in OpsWithShapeInference:
+ traits.append("DeclareOpInterfaceMethods")
+ s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits))
+
+ # Generate decl for canonicalizer.
+ indent = inc_indent(indent)
+ if schema.name in OpsWithCanonicalizer:
+ s += indent + 'let hasCanonicalizer = 1;\n'
+
+ # Generate decl for summary.
+ s += indent + 'let summary = "ONNX {} operation";\n'.format(schema.name)
+
+ # Generate description.
+ s += indent + 'let description = [{\n'
+ if schema.doc:
+ lines = schema.doc.lstrip().splitlines()
+ for line in lines:
+ escaped_line = line.replace('"', '\\"')\
+ .replace('}]', '\\}\\]')
+ s += indent + '"{}"\n'.format(escaped_line)
+ s += indent + '}];\n'
+
+ # Generate ins (consisting of operands and attributes).
+ ins = get_operands_or_results(schema, is_input=True)
+ ins.update(get_attrs(schema))
+ ins_strs = ["{1}:${0}".format(*i) for i in ins.items()]
+ s += indent + 'let arguments = (ins {});\n'.format(
+ (',\n' + inc_indent(indent)).join(ins_strs))
+
+ # Generate outs (operation results).
+ outs = get_operands_or_results(schema, is_input=False)
+ outs_strs = ["{1}:${0}".format(*i) for i in outs.items()]
+ s += indent + 'let results = (outs {});\n'.format(
+ (',\n' + inc_indent(indent)).join(outs_strs))
# add custom builders
# use element type of the first operand to construct an UnrankedTensorType for the output.
if schema.name in custom_builder_ops_list:
- if len(operand_ins) == 0:
- print("warning: not generate custom build methods for " + schema.name + " since it does not have operands.")
+ if len(ins) == 0:
+ raise RuntimeWarning(
+ "warning: not generate custom build methods for " +
+ schema.name + " since it does not have operands.")
else:
- if get_op_type_for_builder(operand_ins[0][0]) == 'ValueRange':
- first_operand = operand_ins[0][1]+'[0]'
- else:
- first_operand = operand_ins[0][1]
-
- s += line_indent+'let builders = [\n'
-
- # custom builders with operands and attributes having a seperate parameter.
+ s += indent + 'let builders = [\n'
+        # Custom builders with operands and attributes having a separate parameter.
# E.g. OpBuilder<"Builder *builder, OperationState &state, Value X, Value, Y, Attribute A", [{}]>
- s += line_indent*2+'OpBuilder<"Builder *builder, OperationState &state'
- for arg_type, arg_name in operand_ins:
- s += ', '+get_op_type_for_builder(arg_type)+' '+arg_name
- for attr_type, attr_name in attr_ins:
- s += ', '+get_attr_type_for_builder(attr_type)+' '+attr_name
+ indent = inc_indent(indent)
+ s += indent + 'OpBuilder<"Builder *builder, OperationState &state'
+ operands_dict = get_operands_or_results(schema, is_input=True)
+ for name, ty in operands_dict.items():
+ s += ', {} {}'.format(tblgen_operand_type_to_cpp_type(ty),
+ name)
+ for name, ty in get_attrs(schema).items():
+ s += ', {} {}'.format(tblgen_attr_type_to_cpp_type(ty), name)
s += '", [{\n'
-        s += line_indent*3+'auto elementType = '+first_operand+'.getType().cast<TensorType>().getElementType();\n'
- s += line_indent*3+'build(builder, state, UnrankedTensorType::get(elementType)'
- for _, arg_name in operand_ins:
- s += ', '+arg_name
- for _, attr_name in attr_ins:
- s += ', '+attr_name
+ indent = inc_indent(indent)
+
+ # Get output type from first operand's type.
+ first_operand_name = list(ins.items())[0][0]
+        s += indent + 'auto elementType = {}.getType().cast<TensorType>().getElementType();\n'.format(
+ first_operand_name)
+ s += indent + 'build(builder, state, UnrankedTensorType::get(elementType)'
+ for name, _ in ins.items():
+ s += ', ' + name
s += ');\n'
- s += line_indent*2+'}]>,\n'
+ indent = dec_indent(indent)
+ s += indent + '}]>,\n'
- # custom builders with all operands and attributes having aggregate parameters.
+ # Custom builders with all operands and attributes having aggregate parameters.
# E.g. OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{}]>'
-    s += line_indent*2+'OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
-    s += line_indent*3+'auto elementType = '+first_operand+'.getType().cast<TensorType>().getElementType();\n'
-    s += line_indent*3+'std::vector<mlir::Type> outputTypes;\n'
- s += line_indent*3+'outputTypes.emplace_back(UnrankedTensorType::get(elementType));\n'
- s += line_indent*3+'build(builder, state, outputTypes, operands, attributes);\n'
- s += line_indent*2+'}]>'
+ s += indent + 'OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{\n'
+ indent = inc_indent(indent)
+ s += indent + 'auto elementType = operands[0].getType().cast().getElementType();\n'
+ s += indent + 'std::vector outputTypes;\n'
+ s += indent + 'outputTypes.emplace_back(UnrankedTensorType::get(elementType));\n'
+ s += indent + 'build(builder, state, outputTypes, operands, attributes);\n'
+ indent = dec_indent(indent)
+ s += indent + '}]>'
- s += '\n'+line_indent+'];\n'
-
- #add special code
- if schema.name in manual_code_in_op_def :
- s += manual_code_in_op_def[schema.name]
+ s += '\n' + indent + '];\n'
s += '}\n\n'
-
return s
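+
+# For a simple op such as Relu, gen_op_def emits roughly (illustrative,
+# abbreviated):
+#   def ONNXReluOp:ONNX_Op<"Relu",
+#     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+#     let summary = "ONNX Relu operation";
+#     let description = [{ ... }];
+#     let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
+#     let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
+#   }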
+
"""
special cases:
* Split: attr split default value: sizeof(output1) namely 1
@@ -470,328 +412,101 @@ special cases:
* Transpose: attr perm default value is {} empty int list
"""
-def gen_code(schema,fefile) :
- handle_variadic = False
+def gen_op_importer(schema, file):
+ indent = inc_indent()
+ s = indent + 'if (opName == "' + schema.name + '")\n'
- line_indent = ' '
- fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
- op_type_str='mlir::ONNX'+schema.name+'Op'
- if schema.name in special_op_handler :
- fefile.write(' '+special_op_handler[schema.name]+'(node, '
- +str(len(schema.inputs))
- +', ' +str(len(schema.outputs)))
- elif len(schema.outputs) > 1 :
- fefile.write(' '+'ImportNodeMultipleOuts<'+op_type_str+'>(node, '
- +str(len(schema.inputs))
- +', ' +str(len(schema.outputs)))
- else :
- fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, '
- +str(len(schema.inputs))
- +', ' +str(len(schema.outputs)))
-
- variadicIn = 'false'
- variadicOut = 'false'
+ expected_num_operands = len(schema.inputs)
+ expected_num_results = len(schema.outputs)
for input in schema.inputs:
if OpSchema.FormalParameterOption.Variadic == input.option:
- if input.isHomogeneous:
- variadicIn = 'true'
- handle_variadic = True
+ expected_num_operands = -1
for output in schema.outputs:
if OpSchema.FormalParameterOption.Variadic == output.option:
- if output.isHomogeneous:
- variadicOut = 'true'
- if not handle_variadic:
- fefile.write(');\n')
- else:
- fefile.write(', '+variadicIn+', '+variadicOut+');\n')
+ expected_num_results = -1
-def get_operand_ins(schema):
- operand_type_and_name_list = [] # [(optype, opname)]
- if schema.inputs:
- for input in schema.inputs:
- optype = ""
+ handler_func = special_op_handler.get(
+ schema.name, "buildOperation".format(schema.name))
- etypes=collect_types(schema, input)
+ # Special handlers currently require expected num operands/results to be specified.
+ # TODO: remove special handlers.
+ args = ["node"]
+ if expected_num_operands != -1 or expected_num_results != -1 or "buildOperation" not in handler_func:
+ args.append(
+ "/* expected_num_operands = */ {}".format(expected_num_operands))
+ args.append(
+ '/* expected_num_results = */ {}'.format(expected_num_results))
+ s += inc_indent(indent) + "return {}({});\n".format(
+ handler_func, ", ".join(args))
- if OpSchema.FormalParameterOption.Optional == input.option:
- #TODO : handle optional
- print("warning: optional input for"+schema.name+' '+input.name)
- elif OpSchema.FormalParameterOption.Variadic == input.option:
- if input.isHomogeneous:
- optype += 'Variadic<'
- else:
- #TODO handle(variadic, heterogeneous) "
- print("warning: (variadic, heterogeneous) for"+schema.name+' '+input.name)
- if etypes == '':
- optype += 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
- else:
- optype += 'TensorOf<['+etypes+']>'
+ file.write(s)
- if OpSchema.FormalParameterOption.Optional == input.option:
- #TODO : handle optional
- t=''
- elif OpSchema.FormalParameterOption.Variadic == input.option:
- if input.isHomogeneous:
- optype += '>'
- else:
- #TODO handle(variadic, heterogeneous) "
- t=''
- operand_type_and_name_list.append((optype, input.name))
- return operand_type_and_name_list
-def get_attr_ins(schema) :
-
- def get_attr_type_basic(attr_type) :
- if attr_type == 'int' :
- mytype = 'I64Attr'
- elif attr_type == 'float' :
- mytype = 'F32Attr'
- elif attr_type == 'ints' :
- mytype = 'I64ArrayAttr'
- elif attr_type == 'floats' :
- mytype = 'F32ArrayAttr'
- elif attr_type == "string" :
- mytype = 'StrAttr'
- elif attr_type == "strings" :
- mytype = 'StrArrayAttr'
- else :
- mytype ='AnyAttr'
- #TODO: tensor and sparse tensor
- return mytype
+def build_operator_schemas():
+ # domain -> support level -> name -> [schema]
+ index = defaultdict(lambda: defaultdict(lambda: defaultdict(
+ list))) # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
+ for schema in defs.get_all_schemas_with_history():
+ index[schema.domain][int(
+ schema.support_level)][schema.name].append(schema)
- def get_attr_type_optional(attr_type) :
- mytype = 'OptionalAttr<'
- mytype += get_attr_type_basic(attr_type)
- mytype += '>'
- return mytype
+ # Preprocess the Operator Schemas
+ # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
+ operator_schemas = list(
+ ) # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
+ existing_ops = set() # type: Set[Text]
+ for domain, _supportmap in sorted(index.items()):
+ if not should_render_domain(domain):
+ continue
- def get_attr_type_with_default(attr_type, attr_default) :
- mytype = 'DefaultValuedAttr<'
- mytype += get_attr_type_basic(attr_type)
- mytype += ', "'+attr_default+'">'
- return mytype
+ processed_supportmap = list()
+ for _support, _namemap in sorted(_supportmap.items()):
+ processed_namemap = list()
+ for n, unsorted_versions in sorted(_namemap.items()):
+ versions = sorted(unsorted_versions,
+ key=lambda s: s.since_version)
+ schema = versions[-1]
+ if schema.name in existing_ops:
+ continue
+ existing_ops.add(schema.name)
+ processed_namemap.append((n, schema, versions))
+ processed_supportmap.append((_support, processed_namemap))
+ operator_schemas.append((domain, processed_supportmap))
+ return operator_schemas
- attr_type_and_name_list = [] # :: [(attrtype, attrname)]
- attr_line = ''
- if schema.attributes:
- for _, attr in sorted(schema.attributes.items()):
- #attr_line = line_indent+line_indent+line_indent+line_indent
- found = False
- attr_type = ""
- if schema.name+' '+attr.name in special_attr_defaults:
- (attr_type_str, attr_default_str) = special_attr_defaults[schema.name+' '+attr.name]
- attr_type = get_attr_type_with_default(attr_type_str, attr_default_str)
- found = True
- elif attr.required:
- s = Text(attr.type)
- attr_type_str = s[s.rfind('.') + 1:].lower()
- attr_type = get_attr_type_basic(attr_type_str)
- found = True
-
- # option holds either required or default value
- elif attr.default_value.name:
- s = Text(attr.type)
- attr_type_str = s[s.rfind('.') + 1:].lower()
-
- default_value = helper.get_attribute_value(attr.default_value)
- def format_value(value): # type: (Any) -> Text
- if isinstance(value, float):
- formatted = str(np.round(value, 5))
- # use default formatting, unless too long.
- if (len(formatted) > 10):
- formatted = str("({:e})".format(value))
- return formatted
- elif isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3:
- return str(value.decode('utf-8'))
- return str(value)
-
- if isinstance(default_value, list):
- default_value = [format_value(val) for val in default_value]
- attr_option_str = '{}'.format(default_value)
- attr_option_str = attr_option_str.replace('[', '{', 1)
- attr_option_str = attr_option_str.replace(']', '}', 1)
- if attr_type_str == 'strings' :
- attr_option_str = attr_option_str.replace("'", '\\"')
- else :
- attr_option_str = attr_option_str.replace("'", '')
- else:
- default_value = format_value(default_value)
- attr_option_str = default_value
- attr_type = get_attr_type_with_default(attr_type_str, attr_option_str)
- found = True
- else:
- s = Text(attr.type)
- attr_type_str = s[s.rfind('.') + 1:].lower()
- attr_type = get_attr_type_optional(attr_type_str)
- if found:
- attr_type_and_name_list.append((attr_type, attr.name))
- return attr_type_and_name_list
def main(args): # type: (Type[Args]) -> None
- with io.open(args.changelog, 'w', newline='') as fout:
- fout.write('## Operator Changelog\n')
- fout.write(
- "*This file is automatically generated from the\n"
- " [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
- " Do not modify directly and instead edit operator definitions.*\n")
+ curr_utc_time = datetime.datetime.now(
+ datetime.timezone.utc).strftime("%m/%d/%Y, %H:%M:%S")
+ autogen_warning = (
+ '//********************************************************\n'
+ '// This file is generated on UTC-{}.\n'
+ '// Do not modify this file directly.\n'
+ '// This file is automatically generated via script.\n'
+ '// Details can be found in doc/readonnxdefs.md .\n'
+ '//********************************************************\n\n')
+ autogen_warning = autogen_warning.format(curr_utc_time)
- # domain -> version -> [schema]
- dv_index = defaultdict(lambda: defaultdict(list)) # type: Dict[Text, Dict[int, List[OpSchema]]]
- for schema in defs.get_all_schemas_with_history():
- dv_index[schema.domain][schema.since_version].append(schema)
+ op_def = io.open(args.op_def_file, 'w', newline='')
+ op_def.write(autogen_warning)
- fout.write('\n')
+ op_importer = io.open(args.op_importer_file, 'w', newline='')
+ op_importer.write(autogen_warning)
- for domain, versionmap in sorted(dv_index.items()):
- if not should_render_domain(domain):
- continue
-
- s = '# {}\n'.format(display_domain_short(domain))
-
- for version, unsorted_schemas in sorted(versionmap.items()):
- s += '## Version {} of {}\n'.format(version, display_domain(domain))
- for schema in sorted(unsorted_schemas, key=lambda s: s.name):
- name_with_ver = '{}-{}'.format(format_name_with_domain(domain, schema.name),
- schema.since_version)
- s += ('### <a name="{}"></a>**{}**' + (' (deprecated)' if schema.deprecated else '') + '\n').format(name_with_ver, name_with_ver)
- s += display_schema(schema, [schema])
- s += '\n'
-
- fout.write(s)
-
- with io.open(args.output, 'w', newline='', encoding="utf-8") as fout:
- fout.write('## Operator Schemas\n')
- fout.write(
- "*This file is automatically generated from the\n"
- " [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
- " Do not modify directly and instead edit operator definitions.*\n")
-
- # domain -> support level -> name -> [schema]
- index = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
- for schema in defs.get_all_schemas_with_history():
- #print("check point 0", schema.name, schema.domain, schema.support_level)
- #gen_schema(schema)
- index[schema.domain][int(schema.support_level)][schema.name].append(schema)
-
- fout.write('\n')
-
- # Preprocess the Operator Schemas
- # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
- operator_schemas = list() # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
- exsting_ops = set() # type: Set[Text]
- for domain, _supportmap in sorted(index.items()):
- if not should_render_domain(domain):
- continue
-
- processed_supportmap = list()
- for _support, _namemap in sorted(_supportmap.items()):
- processed_namemap = list()
- for n, unsorted_versions in sorted(_namemap.items()):
- versions = sorted(unsorted_versions, key=lambda s: s.since_version)
- schema = versions[-1]
- #print("check point 2", schema)
- if schema.name in exsting_ops:
- continue
- exsting_ops.add(schema.name)
- processed_namemap.append((n, schema, versions))
- processed_supportmap.append((_support, processed_namemap))
- operator_schemas.append((domain, processed_supportmap))
-
- # Table of contents
- for domain, supportmap in operator_schemas:
- s = '* {}\n'.format(display_domain_short(domain))
- fout.write(s)
- function_ops = list()
- for _, namemap in supportmap:
- for n, schema, versions in namemap:
- if schema.has_function: # type: ignore
- function_ops.append((n, schema, versions))
- continue
- s = ' * {}<a href="#{}">{}</a>\n'.format(
- support_level_str(schema.support_level),
- format_name_with_domain(domain, n),
- format_name_with_domain(domain, n))
- fout.write(s)
- if len(function_ops):
- fout.write('\n')
- fout.write(' **Operators with function registered:**\n')
- for n, schema, versions in function_ops:
- s = ' * {}<a href="#{}">{}</a>\n'.format(
- support_level_str(schema.support_level),
- format_name_with_domain(domain, n),
- format_name_with_domain(domain, n))
- fout.write(s)
-
- fout.write('\n')
- tdfile= io.open(args.tdfile, 'w', newline='')
- tdfile.write('//********************************************************\n'+
- '// Warning: Do not modify this file directly\n'+
- '// This file is automatically generated via script\n'+
- '// Details can be found in doc/readonnxdefs.md\n'+
- '//********************************************************\n\n'
- )
- fefile=io.open('op_build_table.inc', 'w', newline='')
- firstfunc = True
-
- fefile.write('//********************************************************\n'+
- '// Warning: Do not modify this file directly\n'+
- '// This file is automatically generated via script\n'+
- '// Details can be found in doc/readonnxdefs.md\n'+
- '//********************************************************\n\n'
- )
- fefile.write(' '+'if (OpName == "DUMMY") {\n')
- for domain, supportmap in operator_schemas:
- s = '## {}\n'.format(display_domain_short(domain))
- fout.write(s)
-
- for _, namemap in supportmap:
- for op_type, schema, versions in namemap:
- # op_type
- #print("check point 1", schema.name, len(schema.inputs), len(schema.outputs))
- gen_code(schema, fefile)
-
- r = gen_schema(schema)
- tdfile.write(r)
- s = ('### {}<a name="{}"></a><a name="{}">**{}**</a>' + (' (deprecated)' if schema.deprecated else '') + '\n').format(
- support_level_str(schema.support_level),
- format_name_with_domain(domain, op_type),
- format_name_with_domain(domain, op_type.lower()),
- format_name_with_domain(domain, op_type))
-
- s += display_schema(schema, versions)
-
- s += '\n\n'
-
- if op_type in SNIPPETS:
- s += '#### Examples\n\n'
- for summary, code in sorted(SNIPPETS[op_type]):
- s += '<details>\n'
- s += '<summary>{}</summary>\n\n'.format(summary)
- s += '```python\n{}\n```\n\n'.format(code)
- s += '</details>\n'
- s += '\n\n'
- if op_type.lower() in SAMPLE_IMPLEMENTATIONS:
- s += '#### Sample Implementation\n\n'
- s += '<details>\n'
- s += '<summary>{}</summary>\n\n'.format(op_type)
- s += '```python\n{}\n```\n\n'.format(SAMPLE_IMPLEMENTATIONS[op_type.lower()])
- s += '</details>\n'
- s += '\n\n'
-
- fout.write(s)
- fefile.write(' }')
- fefile.close()
+ for domain, supportmap in build_operator_schemas():
+ for _, namemap in supportmap:
+ for op_type, schema, versions in namemap:
+ gen_op_importer(schema, op_importer)
+ r = gen_op_def(schema)
+ op_def.write(r)
if __name__ == '__main__':
- base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
- docs_dir = os.path.join(base_dir, 'docs')
- print(docs_dir)
+ curr_dir = os.path.dirname(os.path.realpath(__file__))
class Args(object):
- output = os.path.join(docs_dir, 'Operators' + ext)
- changelog = os.path.join(docs_dir, 'Changelog' + ext)
- tdfile = os.path.join(base_dir, 'onnxop.inc')
- print(Args)
+ op_def_file = os.path.join(curr_dir, 'onnxop.inc')
+ op_importer_file = os.path.join(curr_dir, 'op_build_table.inc')
+
main(Args)
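
For reference, the dispatch entry that gen_op_importer emits for a variadic op such as Concat looks like the sketch below (reassembled here from the string-building logic above rather than copied from generator output; -1 is the sentinel for a variadic operand or result count):

if (opName == "Concat")
  return buildOperation<mlir::ONNXConcatOp>(
      node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);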
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index d895be5..b210275 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -62,7 +62,21 @@ target_include_directories(onnf_shape_inference
target_link_libraries(onnf_shape_inference ${MLIRLibs})
add_dependencies(onnf_shape_inference gen_krnl_ops)
-add_library(onnf_lower_frontend conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
+add_library(onnf_lower_frontend
+ conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
+ conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
+ conversion/onnx_to_krnl/math/elementwise.cpp
+ conversion/onnx_to_krnl/math/gemm.cpp
+ conversion/onnx_to_krnl/math/matmul.cpp
+ conversion/onnx_to_krnl/math/reduction.cpp
+ conversion/onnx_to_krnl/math/softmax.cpp
+ conversion/onnx_to_krnl/nn/conv.cpp
+ conversion/onnx_to_krnl/nn/normalization.cpp
+ conversion/onnx_to_krnl/tensor/identity.cpp
+ conversion/onnx_to_krnl/tensor/reshape.cpp
+ conversion/onnx_to_krnl/tensor/transpose.cpp
+ conversion/onnx_to_krnl/tensor/unsqueeze.cpp
+ conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
target_include_directories(onnf_lower_frontend
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT})
diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp
index 9cadad8..0efca22 100644
--- a/src/builder/frontend_dialect_transformer.cpp
+++ b/src/builder/frontend_dialect_transformer.cpp
@@ -121,6 +121,7 @@ private:
mlir::MLIRContext &context_;
mlir::ModuleOp module_;
mlir::OpBuilder builder_;
+ mlir::Value none_;
// mapping between string name and symbol
OnnxOnnfSymbolMapping frontend_symbols_;
@@ -188,8 +189,9 @@ private:
}
}
- mlir::Type elementType =
- convertONNXTypeToMLIRType(input.type().tensor_type().elem_type());
+ auto elementOnnxType =
+ (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
+ mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
arg_types.emplace_back(
mlir::RankedTensorType::get(tensor_dims, elementType));
@@ -287,8 +289,8 @@ private:
}
}
- std::vector<mlir::NamedAttribute> ImportNodeAttributes(
- const onnx::NodeProto &node) {
+ std::vector<mlir::NamedAttribute>
+ ImportNodeAttributes(const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute> attributes;
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
@@ -317,21 +319,11 @@ private:
}
}
- // if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
- // combined with 'if constexpr' the issue is the type of the output is
- // different. alternative way to use variadic output for all the op
-
- /*!
- * Important onnx node which generates only one output
- * @param node onnx node
- * @param nIn number of expected inputs
- * @param nOut number of expected outputs
- * @param attrs list of desription for attributes with format {name, type,
- * default}
- */
template <typename T>
- void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
- bool variadicIn = false, bool variadicOut = false) {
+ void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
+ int expectedNumResults = -1) {
+ bool variadicIn = expectedNumOperands == -1;
+ bool variadicOut = expectedNumResults == -1;
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
@@ -339,6 +331,10 @@ private:
}
}
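+ // When the signature is fixed, pad absent trailing optional operands with the none_ value.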
+ if (!variadicIn)
+ for (auto i = inputs.size(); i < expectedNumOperands; i++)
+ inputs.emplace_back(none_);
+
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
@@ -347,49 +343,11 @@ private:
auto attributes = ImportNodeAttributes(node);
- llvm::StringRef OpName = node.op_type();
- if ((variadicIn || nIn == inputs.size()) &&
- (variadicOut || nOut == outputTypes.size())) {
- auto op =
- builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
- frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
- op.getResult());
- } else {
- ImportNodeGeneric(node);
- }
- }
-
- template <typename T>
- void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut,
- bool variadicIn = false,
- bool variadicOut = false) {
- std::vector<mlir::Value> inputs;
- for (const auto &item : node.input()) {
- if (frontend_symbols_.ContainKey(legalize_name(item))) {
- inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
- }
- }
-
- std::vector<mlir::Type> outputTypes;
- for (auto item : node.output()) {
- outputTypes.push_back(
- mlir::UnrankedTensorType::get(builder_.getF32Type()));
- }
-
- auto attributes = ImportNodeAttributes(node);
-
- llvm::StringRef OpName = node.op_type();
-
- if ((variadicIn || nIn == inputs.size()) &&
- (variadicOut || nOut == outputTypes.size())) {
- auto op =
- builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
- for (int i = 0; i < node.output().size(); i++) {
- frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
- op.getResult(i));
- }
- } else {
- ImportNodeGeneric(node);
+ // TODO: Handle optional inputs.
+ auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
+ for (int i = 0; i < node.output().size(); i++) {
+ frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
+ *(op.getODSResults(i).begin()));
}
}
@@ -398,8 +356,7 @@ private:
* c++ does not allow template specialization inside a class scope
* a specialized function is used
*/
- void
- ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
+ void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
// Conv has attribute dilations, kernel_shape, pads, the default value of
// which is determined by the shape of first argument. However, since the
// shape is unknown now, these attributes can be not generated auto
@@ -413,24 +370,20 @@ private:
int nOps = node.input().size();
if (nOps == 2)
- ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(
- node, nOps, nOut);
+ buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
else
- ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut);
+ buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
}
/*!
* Special handle for MaxPool operations.
*/
- void ImportNodeMaxPool(
- onnx::NodeProto node, int nIn, int nOut) {
+ void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) {
int nOuts = node.output().size();
if (nOuts == 1) {
- ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(
- node, nIn, nOuts);
+ buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts);
} else {
- ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(
- node, nIn, nOuts);
+ buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts);
}
}
@@ -441,23 +394,10 @@ private:
int nOuts = node.output().size();
if (nOuts == 1) {
// Test mode with one output.
- ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn,
- nOuts);
+ buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts);
} else {
// Training mode with four trailing optional outputs. Not handled yet.
- ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
- }
- }
-
- /*!
- * Special handle for Gemm operations.
- */
- void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) {
- int nOps = node.input().size();
- if (nOps == 2) {
- ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut);
- } else {
- ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut);
+ buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
}
}
@@ -467,28 +407,14 @@ private:
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size();
if (nOps == 2) {
- ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
+ buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
} else {
- ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut);
+ buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
}
}
void ImportNode(const onnx::NodeProto &node) {
- std::vector<mlir::Value> inputs;
- for (const auto &item : node.input()) {
- if (frontend_symbols_.ContainKey(legalize_name(item))) {
- inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
- }
- }
-
- std::vector<mlir::Type> outputTypes;
- for (auto item : node.output()) {
- outputTypes.push_back(
- mlir::UnrankedTensorType::get(builder_.getF32Type()));
- }
-
- std::vector<mlir::NamedAttribute> attributes;
- llvm::StringRef OpName = node.op_type();
+ llvm::StringRef opName = node.op_type();
// the following code is generated by gen_doc.py
// refer to dialect/onnx/onnx.td for details
@@ -555,9 +481,11 @@ private:
ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
}
- // import nodes in the graph
- auto node = graph.node();
- for (const auto &item : node) {
+ // Create a NoneTyped constant.
+ none_ =
+ builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
+ // Import nodes in the graph.
+ for (const auto &item : graph.node()) {
ImportNode(item);
}
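
To illustrate the new optional-input handling, here is a minimal sketch (hypothetical SSA value names; not part of the patch) of what buildOperation does for a node that omits trailing optional inputs:

// ONNX "Clip" declares three inputs (data, min, max); suppose a model wires only data.
// The dispatch table calls buildOperation<mlir::ONNXClipOp>(node, 3, 1):
//   inputs = { %data }                      // resolved via frontend_symbols_
//   the padding loop appends none_ until inputs.size() == 3:
//   inputs = { %data, %none, %none }
// builder_.create<mlir::ONNXClipOp> then receives a full, fixed-arity operand list.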
diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc
index c0b2ca6..41a910f 100644
--- a/src/builder/op_build_table.inc
+++ b/src/builder/op_build_table.inc
@@ -1,320 +1,319 @@
//********************************************************
-// Warning: Do not modify this file directly
-// This file is automatically generated via script
-// Details can be found in doc/readonnxdefs.md
+// This file is generated on UTC-02/24/2020, 06:29:01.
+// Do not modify this file directly.
+// This file is automatically generated via script.
+// Details can be found in doc/readonnxdefs.md .
//********************************************************
- if (OpName == "DUMMY") {
- }else if (OpName == "Abs") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Acos") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Acosh") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Add") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "And") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "ArgMax") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ArgMin") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Asin") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Asinh") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Atan") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Atanh") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "AveragePool") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "BatchNormalization") {
- ImportNodeBatchNormalization(node, 5, 5);
- }else if (OpName == "BitShift") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Cast") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Ceil") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Clip") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "Compress") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Concat") {
- ImportNodeOneOut(node, 1, 1, true, false);
- }else if (OpName == "ConcatFromSequence") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Constant") {
- ImportNodeOneOut(node, 0, 1);
- }else if (OpName == "ConstantOfShape") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Conv") {
- ImportNodeConv(node, 3, 1);
- }else if (OpName == "ConvInteger") {
- ImportNodeOneOut(node, 4, 1);
- }else if (OpName == "ConvTranspose") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "Cos") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Cosh") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "CumSum") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "DepthToSpace") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "DequantizeLinear") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "Det") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Div") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Dropout") {
- ImportNodeMultipleOuts(node, 1, 2);
- }else if (OpName == "DynamicQuantizeLinear") {
- ImportNodeMultipleOuts(node, 1, 3);
- }else if (OpName == "Elu") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Equal") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Erf") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Exp") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Expand") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "EyeLike") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Flatten") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Floor") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "GRU") {
- ImportNodeMultipleOuts(node, 6, 2);
- }else if (OpName == "Gather") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "GatherElements") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "GatherND") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Gemm") {
- ImportNodeGemm(node, 3, 1);
- }else if (OpName == "GlobalAveragePool") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "GlobalLpPool") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "GlobalMaxPool") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Greater") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "HardSigmoid") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Hardmax") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Identity") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "If") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "InstanceNormalization") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "IsInf") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "IsNaN") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "LRN") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "LSTM") {
- ImportNodeMultipleOuts(node, 8, 3);
- }else if (OpName == "LeakyRelu") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Less") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Log") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "LogSoftmax") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Loop") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "LpNormalization") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "LpPool") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "MatMul") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "MatMulInteger") {
- ImportNodeOneOut(node, 4, 1);
- }else if (OpName == "Max") {
- ImportNodeOneOut(node, 1, 1, true, false);
- }else if (OpName == "MaxPool") {
- ImportNodeMaxPool(node, 1, 2);
- }else if (OpName == "MaxRoiPool") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "MaxUnpool") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "Mean") {
- ImportNodeOneOut(node, 1, 1, true, false);
- }else if (OpName == "MeanVarianceNormalization") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Min") {
- ImportNodeOneOut(node, 1, 1, true, false);
- }else if (OpName == "Mod") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Mul") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Multinomial") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Neg") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "NonMaxSuppression") {
- ImportNodeOneOut(node, 5, 1);
- }else if (OpName == "NonZero") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Not") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "OneHot") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "Or") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "PRelu") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Pad") {
- ImportNodePad(node, 3, 1);
- }else if (OpName == "Pow") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "QLinearConv") {
- ImportNodeOneOut(node, 9, 1);
- }else if (OpName == "QLinearMatMul") {
- ImportNodeOneOut(node, 8, 1);
- }else if (OpName == "QuantizeLinear") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "RNN") {
- ImportNodeMultipleOuts(node, 6, 2);
- }else if (OpName == "RandomNormal") {
- ImportNodeOneOut(node, 0, 1);
- }else if (OpName == "RandomNormalLike") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "RandomUniform") {
- ImportNodeOneOut(node, 0, 1);
- }else if (OpName == "RandomUniformLike") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Range") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "Reciprocal") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ReduceL1") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ReduceL2") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ReduceLogSum") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ReduceLogSumExp") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ReduceMax") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ReduceMean") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ReduceMin") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ReduceProd") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ReduceSum") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ReduceSumSquare") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Relu") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Reshape") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Resize") {
- ImportNodeOneOut(node, 4, 1);
- }else if (OpName == "ReverseSequence") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "RoiAlign") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "Round") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Scan") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Scatter") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "ScatterElements") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "ScatterND") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "Selu") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "SequenceAt") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "SequenceConstruct") {
- ImportNodeOneOut(node, 1, 1, true, false);
- }else if (OpName == "SequenceEmpty") {
- ImportNodeOneOut(node, 0, 1);
- }else if (OpName == "SequenceErase") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "SequenceInsert") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "SequenceLength") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Shape") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Shrink") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Sigmoid") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Sign") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Sin") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Sinh") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Size") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Slice") {
- ImportNodeOneOut(node, 5, 1);
- }else if (OpName == "Softmax") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Softplus") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Softsign") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "SpaceToDepth") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Split") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "SplitToSequence") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Sqrt") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Squeeze") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "StringNormalizer") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Sub") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Sum") {
- ImportNodeOneOut(node, 1, 1, true, false);
- }else if (OpName == "Tan") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Tanh") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "TfIdfVectorizer") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "ThresholdedRelu") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Tile") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "TopK") {
- ImportNodeMultipleOuts(node, 2, 2);
- }else if (OpName == "Transpose") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Unique") {
- ImportNodeMultipleOuts(node, 1, 4);
- }else if (OpName == "Unsqueeze") {
- ImportNodeOneOut(node, 1, 1);
- }else if (OpName == "Upsample") {
- ImportNodeOneOut(node, 2, 1);
- }else if (OpName == "Where") {
- ImportNodeOneOut(node, 3, 1);
- }else if (OpName == "Xor") {
- ImportNodeOneOut(node, 2, 1);
- }
\ No newline at end of file
+if (opName == "Abs")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Acos")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Acosh")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Add")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "And")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "ArgMax")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "ArgMin")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Asin")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Asinh")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Atan")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Atanh")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "AveragePool")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "BatchNormalization")
+ return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
+if (opName == "BitShift")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "Cast")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Ceil")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Clip")
+ return buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+if (opName == "Compress")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "Concat")
+ return buildOperation(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
+if (opName == "ConcatFromSequence")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Constant")
+ return buildOperation(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
+if (opName == "ConstantOfShape")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Conv")
+ return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+if (opName == "ConvInteger")
+ return buildOperation(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
+if (opName == "ConvTranspose")
+ return buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+if (opName == "Cos")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Cosh")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "CumSum")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "DepthToSpace")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "DequantizeLinear")
+ return buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+if (opName == "Det")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Div")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "Dropout")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
+if (opName == "DynamicQuantizeLinear")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
+if (opName == "Elu")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Equal")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "Erf")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Exp")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Expand")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "EyeLike")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Flatten")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Floor")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "GRU")
+ return buildOperation(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
+if (opName == "Gather")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "GatherElements")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "GatherND")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "Gemm")
+ return buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+if (opName == "GlobalAveragePool")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "GlobalLpPool")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "GlobalMaxPool")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Greater")
+ return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+if (opName == "HardSigmoid")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Hardmax")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "Identity")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "If")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
+if (opName == "InstanceNormalization")
+ return buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+if (opName == "IsInf")
+ return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+if (opName == "IsNaN")
+ return buildOperation