Merge remote-tracking branch 'upstream/master' into shapeinference-pad

This commit is contained in:
chentong 2020-02-25 15:54:18 -05:00
commit 4079ee1f26
35 changed files with 4683 additions and 4636 deletions

View File

@ -38,7 +38,7 @@ jobs:
- run:
name: Run End-To-End Tests
command: |
sudo pip install -q onnx
sudo pip install -q -e ./ONNF/third_party/onnx
cd ONNF/build
cmake --build . --target run-onnx-backend-test
- run:

View File

@ -1,2 +1,3 @@
BasedOnStyle: LLVM
AlwaysBreakTemplateDeclarations: Yes
AlignAfterOpenBracket: DontAlign

142
.gitignore vendored
View File

@ -30,3 +30,145 @@
*.exe
*.out
*.app
# Filesystem
.DS_Store
# The following .gitignore content is taken from
# https://github.com/github/gitignore/blob/master/Python.gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/

View File

@ -327,10 +327,10 @@ ONNX BatchNormalization operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `out_mean`: memref of any type values or tensor of any type values
1. `out_var`: memref of any type values or tensor of any type values
1. `saved_mean`: memref of any type values or tensor of any type values
1. `saved_var`: memref of any type values or tensor of any type values
1. `out_mean`: memref of any type values or tensor of any type values or none type
1. `out_var`: memref of any type values or tensor of any type values or none type
1. `saved_mean`: memref of any type values or tensor of any type values or none type
1. `saved_var`: memref of any type values or tensor of any type values or none type
### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
ONNX BatchNormalization operation in test mode
@ -375,12 +375,12 @@ ONNX BitShift operation
"Bitwise shift operator performs element-wise operation. For each input element, if the"
" attribute "direction" is "RIGHT", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute "direction""
" is "LEFT", bits of binary representation moves toward the left side, which results the"
" attribute \"direction\" is \"RIGHT\", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute \"direction\""
" is \"LEFT\", bits of binary representation moves toward the left side, which results the"
" increase of its actual value. The input X is the tensor to be shifted and another input"
" Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with"
" Y specifies the amounts of shifting. For example, if \"direction\" is \"Right\", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If \"direction\" is \"LEFT\" with"
" X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]."
" "
" Because this operator supports Numpy-style broadcasting, X's and Y's shapes are"
@ -413,15 +413,15 @@ ONNX Cast operation
"the converted type. The 'to' argument must be one of the data types specified"
"in the 'DataType' enum field in the TensorProto message."
""
"Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations"
"(e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may"
"Casting from string tensor in plain (e.g., \"3.14\" and \"1000\") and scientific numeric representations"
"(e.g., \"1e-5\" and \"1E8\") to float types is supported. For example, converting string \"100.5\" to an integer may"
"result 100. There are some string literals reserved for special floating-point values;"
""+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as "314.15926") would be used. "
"Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior."
"\"+INF\" (and \"INF\"), \"-INF\", and \"NaN\" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match \"+INF\" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to \"INF\" and \"NaN\". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as \"314.15926\") would be used. "
"Converting non-numerical-literal string such as \"Hello World!\" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as \"2.718\", to INT is an undefined behavior."
""
"Conversion from a numerical type to any numerical type is always allowed."
"User must be aware of precision loss and value change caused by range difference between two types."
@ -476,8 +476,8 @@ ONNX Clip operation
#### Operands:
1. `input`: memref of any type values or tensor of any type values
1. `min`: memref of any type values or tensor of any type values
1. `max`: memref of any type values or tensor of any type values
1. `min`: memref of any type values or tensor of any type values or none type
1. `max`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -618,8 +618,8 @@ ONNX ConvInteger operation
1. `x`: memref of any type values or tensor of any type values
1. `w`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values
1. `w_zero_point`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values or none type
1. `w_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -678,7 +678,7 @@ ONNX Conv operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -720,7 +720,7 @@ ONNX ConvTranspose operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -884,7 +884,7 @@ ONNX DequantizeLinear operation
1. `x`: memref of any type values or tensor of any type values
1. `x_scale`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -964,7 +964,7 @@ ONNX Dropout operation
#### Results:
1. `output`: memref of any type values or tensor of any type values
1. `mask`: memref of any type values or tensor of any type values
1. `mask`: memref of any type values or tensor of any type values or none type
### onnx.DynamicQuantizeLinear (ONNXDynamicQuantizeLinearOp)
ONNX DynamicQuantizeLinear operation
@ -1297,9 +1297,9 @@ ONNX GRU operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `sequence_lens`: memref of any type values or tensor of any type values
1. `initial_h`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -1315,8 +1315,8 @@ ONNX GRU operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Y_h`: memref of any type values or tensor of any type values
1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.GatherElements (ONNXGatherElementsOp)
ONNX GatherElements operation
@ -1558,33 +1558,6 @@ ONNX Gather operation
1. `output`: memref of any type values or tensor of any type values
### onnx.GemmNoBias (ONNXGemmNoBiasOp)
ONNX general matrix multiply operation without bias.
#### Description:
The "onnx.Gemm" generic matrix multiplication without bias.
#### Operands:
1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
#### Attributes:
| Attribute | MLIR Type | Description |
| :-------: | :-------: | ----------- |
| `alpha` | `FloatAttr` | 32-bit float attribute attribute |
| `beta` | `FloatAttr` | 32-bit float attribute attribute |
| `transA` | `IntegerAttr` | 64-bit integer attribute attribute |
| `transB` | `IntegerAttr` | 64-bit integer attribute attribute |
#### Results:
1. `o_Y`: memref of any type values or tensor of any type values
### onnx.Gemm (ONNXGemmOp)
ONNX Gemm operation
@ -1609,7 +1582,7 @@ ONNX Gemm operation
1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `C`: memref of any type values or tensor of any type values
1. `C`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2013,11 +1986,11 @@ ONNX LSTM operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `sequence_lens`: memref of any type values or tensor of any type values
1. `initial_h`: memref of any type values or tensor of any type values
1. `initial_c`: memref of any type values or tensor of any type values
1. `P`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values or none type
1. `initial_c`: memref of any type values or tensor of any type values or none type
1. `P`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2033,9 +2006,9 @@ ONNX LSTM operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Y_h`: memref of any type values or tensor of any type values
1. `Y_c`: memref of any type values or tensor of any type values
1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values or none type
1. `Y_c`: memref of any type values or tensor of any type values or none type
### onnx.LeakyRelu (ONNXLeakyReluOp)
ONNX LeakyRelu operation
@ -2160,24 +2133,24 @@ ONNX Loop operation
""
" Operator inputs defined as (max_trip_count, condition_var)."
""
" input ("", ""):"
" input (\"\", \"\"):"
" for (int i=0; ; ++i) {"
" cond = ... // Note this value is ignored, but is required in the body"
" }"
""
" input ("", cond) // Note this is analogous to a while loop"
" input (\"\", cond) // Note this is analogous to a while loop"
" bool cond = ...;"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
" input ("", 1) // Note this is analogous to a do-while loop"
" input (\"\", 1) // Note this is analogous to a do-while loop"
" bool cond = true"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
" input (trip_count, "") // Note this is analogous to a for loop"
" input (trip_count, \"\") // Note this is analogous to a for loop"
" int trip_count = ..."
" for (int i=0; i < trip_count; ++i) {"
" cond = ...; // ignored"
@ -2203,15 +2176,15 @@ ONNX Loop operation
" }"
""
" graph body-net ("
" %i[INT32, scalar] // iteration number"
" %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used"
" %b_in[INT32, scalar] // incoming value of loop-carried-dependency b"
" %i[INT32, scalar]"
" %keepgoing[BOOL, scalar]"
" %b[INT32, scalar]"
" ) {"
" %my_local = Add(%a, %b_in)"
" %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b"
" %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition"
" %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated"
" return %keepgoing_out, %b_out, %user_defined_val"
" %my_local = Add(%a, %b)"
" %b_out = Sub(%a, %b)"
" %keepgoing_out = Greater(%my_local, %b_out)"
" %user_defined_vals = Add(%b, %b)"
" return %keepgoing_out, %b_out, %user_defined_vals"
" }"
""
"*Sample equivalent C code*"
@ -2226,51 +2199,31 @@ ONNX Loop operation
" const int max_trip_count = 10; // Analogous to input M"
" int user_defined_vals[]; // Imagine this is resizable"
" /* End implicitly-defined code */"
" /* initialize loop-carried variables and scan-output variables */"
" bool keepgoing_out = keepgoing"
" int b_out = b"
""
" for (int i=0; i < max_trip_count && keepgoing_out; ++i) {"
" /* Implicitly-defined code: bind actual parameter values"
" to formal parameter variables of loop-body */"
" bool keepgoing_in = keepgoing_out; "
" bool b_in = b_out;"
""
" for (int i=0; i < max_trip_count && keepgoing; ++i) {"
" /* User-defined code (loop body) */"
" int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine"
" b_out = a - b_in;"
" keepgoing_out = my_local > b_out; "
" user_defined_val = b_in + b_in; // b_in and b_out are different variables"
" int my_local = a + b; // Reading values in the enclosing scope is fine"
" b = a - b; // writes fine if we specify b as a loop-carried dependency"
" keepgoing = my_local > b; // keepgoing is a loop-carried dependency"
" user_defined_vals[i] = b + b;"
" /* End user-defined code */"
""
" /* Implicitly defined-code */"
" user_defined_vals[i] = user_defined_val // accumulate scan-output values"
" }"
" // int t = my_local; // Can't do this. my_local is not accessible here."
" // my_local = 123; // Can't do this. my_local was defined in the the body"
""
" // The values below are bound to the output variables of the loop and therefore accessible"
" // b_out; user_defined_vals; keepgoing_out;"
" // These below values are live-out from the loop and therefore accessible"
" b_out; user_defined_vals; keepgoing_out;"
" }"
""
"There are several things of note in this code snippet:"
""
"1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can"
"1) Values from the enclosing scope (i.e. variable a here) are in scope and can"
" be referenced in the inputs of the loop."
"2) Any values computed in the loop body that needs to be used in a subsequent"
" iteration or after the loop are modelled using a pair of variables in the loop-body,"
" consisting of an input variable (eg., b_in) and an output variable (eg., b_out)."
" These are referred to as loop-carried dependences. The loop operation node"
" supplies the input value of the input variable for the first iteration, and"
" returns the output value of the output variable produced by the final"
" iteration."
"3) Scan_output variables are used to implicitly concatenate values computed across"
" all the iterations. In the above example, the value of user_defined_val computed"
" over all iterations are concatenated and returned as the value of user_defined_vals"
" after the loop."
"4) Values created in the body cannot be accessed in the enclosing scope,"
" except using the mechanism described above."
"2) Any variables which you wish to make available in the enclosing scope (i.e."
" the variables b and keepgoing) must be declared as either loop-carried"
" dependencies (both at the op inputs and output and at the body net input and"
" output) or scan_outputs."
"3) Values created in the body cannot be accessed in the enclosing scope."
""
"Note that the semantics of this op support "diagonal" or "wavefront" execution."
"Note that the semantics of this op support \"diagonal\" or \"wavefront\" execution."
"(See Step 3 here for an example:"
"https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)."
"Frontends should emit multi-layer RNNs as a series of While operators (with"
@ -2280,8 +2233,8 @@ ONNX Loop operation
#### Operands:
1. `M`: memref of any type values or tensor of any type values
1. `cond`: memref of any type values or tensor of any type values
1. `M`: memref of any type values or tensor of any type values or none type
1. `cond`: memref of any type values or tensor of any type values or none type
1. `v_initial`: memref of any type values or tensor of any type values
#### Attributes:
@ -2360,8 +2313,8 @@ ONNX MatMulInteger operation
1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `a_zero_point`: memref of any type values or tensor of any type values
1. `b_zero_point`: memref of any type values or tensor of any type values
1. `a_zero_point`: memref of any type values or tensor of any type values or none type
1. `b_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2466,7 +2419,7 @@ ONNX MaxPool operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Indices`: memref of any type values or tensor of any type values
1. `Indices`: memref of any type values or tensor of any type values or none type
### onnx.MaxPoolSingleOut (ONNXMaxPoolSingleOutOp)
ONNX MaxPool operation with a single output.
@ -2552,7 +2505,7 @@ ONNX MaxUnpool operation
1. `X`: memref of any type values or tensor of any type values
1. `I`: memref of any type values or tensor of any type values
1. `output_shape`: memref of any type values or tensor of any type values
1. `output_shape`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2752,9 +2705,9 @@ ONNX NonMaxSuppression operation
1. `boxes`: memref of any type values or tensor of any type values
1. `scores`: memref of any type values or tensor of any type values
1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values
1. `iou_threshold`: memref of any type values or tensor of any type values
1. `score_threshold`: memref of any type values or tensor of any type values
1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values or none type
1. `iou_threshold`: memref of any type values or tensor of any type values or none type
1. `score_threshold`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3067,7 +3020,7 @@ ONNX Pad operation
1. `data`: memref of any type values or tensor of any type values
1. `pads`: memref of any type values or tensor of any type values
1. `constant_value`: memref of any type values or tensor of any type values
1. `constant_value`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3124,7 +3077,7 @@ ONNX QLinearConv operation
1. `w_zero_point`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3188,7 +3141,7 @@ ONNX QuantizeLinear operation
1. `x`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3270,9 +3223,9 @@ ONNX RNN operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `sequence_lens`: memref of any type values or tensor of any type values
1. `initial_h`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3287,8 +3240,8 @@ ONNX RNN operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Y_h`: memref of any type values or tensor of any type values
1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.RandomNormalLike (ONNXRandomNormalLikeOp)
ONNX RandomNormalLike operation
@ -3813,14 +3766,14 @@ ONNX Resize operation
"Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor."
"Each dimension value of the output tensor is:"
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified."
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified."
#### Operands:
1. `X`: memref of any type values or tensor of any type values
1. `roi`: memref of any type values or tensor of any type values
1. `scales`: memref of any type values or tensor of any type values
1. `sizes`: memref of any type values or tensor of any type values
1. `sizes`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4438,7 +4391,7 @@ ONNX SequenceErase operation
#### Operands:
1. `input_sequence`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4463,7 +4416,7 @@ ONNX SequenceInsert operation
1. `input_sequence`: memref of any type values or tensor of any type values
1. `tensor`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4680,8 +4633,8 @@ ONNX Slice operation
1. `data`: memref of any type values or tensor of any type values
1. `starts`: memref of any type values or tensor of any type values
1. `ends`: memref of any type values or tensor of any type values
1. `axes`: memref of any type values or tensor of any type values
1. `steps`: memref of any type values or tensor of any type values
1. `axes`: memref of any type values or tensor of any type values or none type
1. `steps`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4834,7 +4787,7 @@ ONNX SplitToSequence operation
#### Operands:
1. `input`: memref of any type values or tensor of any type values
1. `split`: memref of any type values or tensor of any type values
1. `split`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4902,9 +4855,9 @@ ONNX StringNormalizer operation
"StringNormalization performs string operations for basic cleaning."
"This operator has only one input (denoted by X) and only one output"
"(denoted by Y). This operator first examines the elements in the X,"
"and removes elements specified in "stopwords" attribute."
"and removes elements specified in \"stopwords\" attribute."
"After removing stop words, the intermediate result can be further lowercased,"
"uppercased, or just returned depending the "case_change_action" attribute."
"uppercased, or just returned depending the \"case_change_action\" attribute."
"This operator only accepts [C]- and [1, C]-tensor."
"If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]"
"if input shape is [C] and shape [1, 1] if input shape is [1, C]."
@ -5034,8 +4987,8 @@ ONNX TfIdfVectorizer operation
"respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output."
"Note that we may consider all skips up to S when generating the n-grams."
""
"The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF","
"The examples used above are true if mode is \"TF\". If mode is \"IDF\", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is \"TFIDF\","
"this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute."
""
"Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor."
@ -5123,9 +5076,9 @@ ONNX TopK operation
" contains the indices of the top k elements (original indices from the input"
" tensor)."
""
"If "largest" is 1 (the default value) then the k largest elements are returned."
"If "sorted" is 1 (the default value) then the resulting k elements will be sorted."
"If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined."
"If \"largest\" is 1 (the default value) then the k largest elements are returned."
"If \"sorted\" is 1 (the default value) then the resulting k elements will be sorted."
"If \"sorted\" is 0, order of returned 'Values' and 'Indices' are undefined."
""
"Given two equivalent values, this operator uses the indices along the axis as"
" a tiebreaker. That is, the element with the lower index will appear first."
@ -5184,7 +5137,7 @@ ONNX Unique operation
"This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. "
"The first output tensor 'Y' contains all unique values or subtensors of the input. "
"The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. \". "
"The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. "
""
"Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. "
@ -5268,9 +5221,9 @@ ONNX Unique operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `indices`: memref of any type values or tensor of any type values
1. `inverse_indices`: memref of any type values or tensor of any type values
1. `counts`: memref of any type values or tensor of any type values
1. `indices`: memref of any type values or tensor of any type values or none type
1. `inverse_indices`: memref of any type values or tensor of any type values or none type
1. `counts`: memref of any type values or tensor of any type values or none type
### onnx.Unsqueeze (ONNXUnsqueezeOp)
ONNX Unsqueeze operation

File diff suppressed because it is too large Load Diff

View File

@ -62,7 +62,21 @@ target_include_directories(onnf_shape_inference
target_link_libraries(onnf_shape_inference ${MLIRLibs})
add_dependencies(onnf_shape_inference gen_krnl_ops)
add_library(onnf_lower_frontend conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
add_library(onnf_lower_frontend
conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
conversion/onnx_to_krnl/math/elementwise.cpp
conversion/onnx_to_krnl/math/gemm.cpp
conversion/onnx_to_krnl/math/matmul.cpp
conversion/onnx_to_krnl/math/reduction.cpp
conversion/onnx_to_krnl/math/softmax.cpp
conversion/onnx_to_krnl/nn/conv.cpp
conversion/onnx_to_krnl/nn/normalization.cpp
conversion/onnx_to_krnl/tensor/identity.cpp
conversion/onnx_to_krnl/tensor/reshape.cpp
conversion/onnx_to_krnl/tensor/transpose.cpp
conversion/onnx_to_krnl/tensor/unsqueeze.cpp
conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
target_include_directories(onnf_lower_frontend
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT})

View File

@ -121,6 +121,7 @@ private:
mlir::MLIRContext &context_;
mlir::ModuleOp module_;
mlir::OpBuilder builder_;
mlir::Value none_;
// mapping between string name and symbol
OnnxOnnfSymbolMapping frontend_symbols_;
@ -188,8 +189,9 @@ private:
}
}
mlir::Type elementType =
convertONNXTypeToMLIRType(input.type().tensor_type().elem_type());
auto elementOnnxType =
(onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
arg_types.emplace_back(
mlir::RankedTensorType::get(tensor_dims, elementType));
@ -287,8 +289,8 @@ private:
}
}
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute>
ImportNodeAttributes(const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute> attributes;
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
@ -317,21 +319,11 @@ private:
}
}
// if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
// combined with 'if constexpr' the issue is the type of the output is
// different. alternative way to use variadic output for all the op
/*!
* Important onnx node which generates only one output
* @param node onnx node
* @param nIn number of expected inputs
* @param nOut number of expected outputs
* @param attrs list of desription for attributes with format {name, type,
* default}
*/
template <typename T>
void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
bool variadicIn = false, bool variadicOut = false) {
void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
int expectedNumResults = -1) {
bool variadicIn = expectedNumOperands == -1;
bool variadicOut = expectedNumResults == -1;
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
@ -339,6 +331,10 @@ private:
}
}
if (!variadicIn)
for (auto i = inputs.size(); i < expectedNumOperands; i++)
inputs.emplace_back(none_);
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
@ -347,49 +343,11 @@ private:
auto attributes = ImportNodeAttributes(node);
llvm::StringRef OpName = node.op_type();
if ((variadicIn || nIn == inputs.size()) &&
(variadicOut || nOut == outputTypes.size())) {
auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
op.getResult());
} else {
ImportNodeGeneric(node);
}
}
template <typename T>
void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut,
bool variadicIn = false,
bool variadicOut = false) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
auto attributes = ImportNodeAttributes(node);
llvm::StringRef OpName = node.op_type();
if ((variadicIn || nIn == inputs.size()) &&
(variadicOut || nOut == outputTypes.size())) {
auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
// TODO: Handle optional inputs.
auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
for (int i = 0; i < node.output().size(); i++) {
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
op.getResult(i));
}
} else {
ImportNodeGeneric(node);
*(op.getODSResults(i).begin()));
}
}
@ -398,8 +356,7 @@ private:
* c++ does not allow template specialization inside a class scope
* a specialized function is used
*/
void
ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
// Conv has attribute dilations, kernel_shape, pads, the default value of
// which is determined by the shape of first argument. However, since the
// shape is unknown now, these attributes can be not generated auto
@ -413,24 +370,20 @@ private:
int nOps = node.input().size();
if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(
node, nOps, nOut);
buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
else
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut);
buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
}
/*!
* Special handle for MaxPool operations.
*/
void ImportNodeMaxPool(
onnx::NodeProto node, int nIn, int nOut) {
void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) {
int nOuts = node.output().size();
if (nOuts == 1) {
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(
node, nIn, nOuts);
buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts);
} else {
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(
node, nIn, nOuts);
buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts);
}
}
@ -441,23 +394,10 @@ private:
int nOuts = node.output().size();
if (nOuts == 1) {
// Test mode with one output.
ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn,
nOuts);
buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts);
} else {
// Training mode with four trailing optional outputs. Not handled yet.
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
}
}
/*!
* Special handle for Gemm operations.
*/
void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size();
if (nOps == 2) {
ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut);
} else {
ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut);
buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
}
}
@ -467,28 +407,14 @@ private:
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size();
if (nOps == 2) {
ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
} else {
ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut);
buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
}
}
void ImportNode(const onnx::NodeProto &node) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
llvm::StringRef OpName = node.op_type();
llvm::StringRef opName = node.op_type();
// the following code is generated by gen_doc.py
// refer to dialect/onnx/onnx.td for details
@ -555,9 +481,11 @@ private:
ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
}
// import nodes in the graph
auto node = graph.node();
for (const auto &item : node) {
// Create a NoneTyped constant.
none_ =
builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
// Import nodes in the graph.
for (const auto &item : graph.node()) {
ImportNode(item);
}

View File

@ -1,320 +1,319 @@
//********************************************************
// Warning: Do not modify this file directly
// This file is automatically generated via script
// Details can be found in doc/readonnxdefs.md
// This file is generated on UTC-02/24/2020, 06:29:01.
// Do not modify this file directly.
// This file is automatically generated via script.
// Details can be found in doc/readonnxdefs.md .
//********************************************************
if (OpName == "DUMMY") {
}else if (OpName == "Abs") {
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1);
}else if (OpName == "Acos") {
ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1);
}else if (OpName == "Acosh") {
ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1);
}else if (OpName == "Add") {
ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1);
}else if (OpName == "And") {
ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1);
}else if (OpName == "ArgMax") {
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1);
}else if (OpName == "ArgMin") {
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1);
}else if (OpName == "Asin") {
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1);
}else if (OpName == "Asinh") {
ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1);
}else if (OpName == "Atan") {
ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1);
}else if (OpName == "Atanh") {
ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1);
}else if (OpName == "AveragePool") {
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1);
}else if (OpName == "BatchNormalization") {
ImportNodeBatchNormalization(node, 5, 5);
}else if (OpName == "BitShift") {
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1);
}else if (OpName == "Cast") {
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1);
}else if (OpName == "Ceil") {
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1);
}else if (OpName == "Clip") {
ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1);
}else if (OpName == "Compress") {
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1);
}else if (OpName == "Concat") {
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, true, false);
}else if (OpName == "ConcatFromSequence") {
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1);
}else if (OpName == "Constant") {
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1);
}else if (OpName == "ConstantOfShape") {
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1);
}else if (OpName == "Conv") {
ImportNodeConv(node, 3, 1);
}else if (OpName == "ConvInteger") {
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1);
}else if (OpName == "ConvTranspose") {
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1);
}else if (OpName == "Cos") {
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1);
}else if (OpName == "Cosh") {
ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1);
}else if (OpName == "CumSum") {
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1);
}else if (OpName == "DepthToSpace") {
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1);
}else if (OpName == "DequantizeLinear") {
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1);
}else if (OpName == "Det") {
ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1);
}else if (OpName == "Div") {
ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1);
}else if (OpName == "Dropout") {
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2);
}else if (OpName == "DynamicQuantizeLinear") {
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3);
}else if (OpName == "Elu") {
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1);
}else if (OpName == "Equal") {
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1);
}else if (OpName == "Erf") {
ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1);
}else if (OpName == "Exp") {
ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1);
}else if (OpName == "Expand") {
ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1);
}else if (OpName == "EyeLike") {
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1);
}else if (OpName == "Flatten") {
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1);
}else if (OpName == "Floor") {
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1);
}else if (OpName == "GRU") {
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2);
}else if (OpName == "Gather") {
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1);
}else if (OpName == "GatherElements") {
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1);
}else if (OpName == "GatherND") {
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1);
}else if (OpName == "Gemm") {
ImportNodeGemm(node, 3, 1);
}else if (OpName == "GlobalAveragePool") {
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1);
}else if (OpName == "GlobalLpPool") {
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1);
}else if (OpName == "GlobalMaxPool") {
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1);
}else if (OpName == "Greater") {
ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1);
}else if (OpName == "HardSigmoid") {
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1);
}else if (OpName == "Hardmax") {
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1);
}else if (OpName == "Identity") {
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1);
}else if (OpName == "If") {
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1);
}else if (OpName == "InstanceNormalization") {
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1);
}else if (OpName == "IsInf") {
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1);
}else if (OpName == "IsNaN") {
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1);
}else if (OpName == "LRN") {
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1);
}else if (OpName == "LSTM") {
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3);
}else if (OpName == "LeakyRelu") {
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1);
}else if (OpName == "Less") {
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1);
}else if (OpName == "Log") {
ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1);
}else if (OpName == "LogSoftmax") {
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1);
}else if (OpName == "Loop") {
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1);
}else if (OpName == "LpNormalization") {
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1);
}else if (OpName == "LpPool") {
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1);
}else if (OpName == "MatMul") {
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1);
}else if (OpName == "MatMulInteger") {
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1);
}else if (OpName == "Max") {
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, true, false);
}else if (OpName == "MaxPool") {
ImportNodeMaxPool(node, 1, 2);
}else if (OpName == "MaxRoiPool") {
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1);
}else if (OpName == "MaxUnpool") {
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1);
}else if (OpName == "Mean") {
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, true, false);
}else if (OpName == "MeanVarianceNormalization") {
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1);
}else if (OpName == "Min") {
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, true, false);
}else if (OpName == "Mod") {
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1);
}else if (OpName == "Mul") {
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1);
}else if (OpName == "Multinomial") {
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1);
}else if (OpName == "Neg") {
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1);
}else if (OpName == "NonMaxSuppression") {
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1);
}else if (OpName == "NonZero") {
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1);
}else if (OpName == "Not") {
ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1);
}else if (OpName == "OneHot") {
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1);
}else if (OpName == "Or") {
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1);
}else if (OpName == "PRelu") {
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1);
}else if (OpName == "Pad") {
ImportNodePad(node, 3, 1);
}else if (OpName == "Pow") {
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1);
}else if (OpName == "QLinearConv") {
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1);
}else if (OpName == "QLinearMatMul") {
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1);
}else if (OpName == "QuantizeLinear") {
ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1);
}else if (OpName == "RNN") {
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2);
}else if (OpName == "RandomNormal") {
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1);
}else if (OpName == "RandomNormalLike") {
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1);
}else if (OpName == "RandomUniform") {
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1);
}else if (OpName == "RandomUniformLike") {
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1);
}else if (OpName == "Range") {
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1);
}else if (OpName == "Reciprocal") {
ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1);
}else if (OpName == "ReduceL1") {
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1);
}else if (OpName == "ReduceL2") {
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1);
}else if (OpName == "ReduceLogSum") {
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1);
}else if (OpName == "ReduceLogSumExp") {
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1);
}else if (OpName == "ReduceMax") {
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1);
}else if (OpName == "ReduceMean") {
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1);
}else if (OpName == "ReduceMin") {
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1);
}else if (OpName == "ReduceProd") {
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1);
}else if (OpName == "ReduceSum") {
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1);
}else if (OpName == "ReduceSumSquare") {
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1);
}else if (OpName == "Relu") {
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1);
}else if (OpName == "Reshape") {
ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1);
}else if (OpName == "Resize") {
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1);
}else if (OpName == "ReverseSequence") {
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1);
}else if (OpName == "RoiAlign") {
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1);
}else if (OpName == "Round") {
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1);
}else if (OpName == "Scan") {
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1);
}else if (OpName == "Scatter") {
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1);
}else if (OpName == "ScatterElements") {
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1);
}else if (OpName == "ScatterND") {
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1);
}else if (OpName == "Selu") {
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1);
}else if (OpName == "SequenceAt") {
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1);
}else if (OpName == "SequenceConstruct") {
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, true, false);
}else if (OpName == "SequenceEmpty") {
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1);
}else if (OpName == "SequenceErase") {
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1);
}else if (OpName == "SequenceInsert") {
ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1);
}else if (OpName == "SequenceLength") {
ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1);
}else if (OpName == "Shape") {
ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1);
}else if (OpName == "Shrink") {
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1);
}else if (OpName == "Sigmoid") {
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1);
}else if (OpName == "Sign") {
ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1);
}else if (OpName == "Sin") {
ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1);
}else if (OpName == "Sinh") {
ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1);
}else if (OpName == "Size") {
ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1);
}else if (OpName == "Slice") {
ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1);
}else if (OpName == "Softmax") {
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1);
}else if (OpName == "Softplus") {
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1);
}else if (OpName == "Softsign") {
ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1);
}else if (OpName == "SpaceToDepth") {
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1);
}else if (OpName == "Split") {
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1);
}else if (OpName == "SplitToSequence") {
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1);
}else if (OpName == "Sqrt") {
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1);
}else if (OpName == "Squeeze") {
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1);
}else if (OpName == "StringNormalizer") {
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1);
}else if (OpName == "Sub") {
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1);
}else if (OpName == "Sum") {
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, true, false);
}else if (OpName == "Tan") {
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1);
}else if (OpName == "Tanh") {
ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1);
}else if (OpName == "TfIdfVectorizer") {
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1);
}else if (OpName == "ThresholdedRelu") {
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1);
}else if (OpName == "Tile") {
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1);
}else if (OpName == "TopK") {
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2);
}else if (OpName == "Transpose") {
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1);
}else if (OpName == "Unique") {
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4);
}else if (OpName == "Unsqueeze") {
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1);
}else if (OpName == "Upsample") {
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1);
}else if (OpName == "Where") {
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1);
}else if (OpName == "Xor") {
ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1);
}
if (opName == "Abs")
return buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Acos")
return buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Acosh")
return buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Add")
return buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "And")
return buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "ArgMax")
return buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ArgMin")
return buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Asin")
return buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Asinh")
return buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Atan")
return buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Atanh")
return buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "AveragePool")
return buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "BatchNormalization")
return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
if (opName == "BitShift")
return buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Cast")
return buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Ceil")
return buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Clip")
return buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Compress")
return buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Concat")
return buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "ConcatFromSequence")
return buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Constant")
return buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "ConstantOfShape")
return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Conv")
return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ConvInteger")
return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "ConvTranspose")
return buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Cos")
return buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Cosh")
return buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "CumSum")
return buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "DepthToSpace")
return buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "DequantizeLinear")
return buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Det")
return buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Div")
return buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Dropout")
return buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
if (opName == "DynamicQuantizeLinear")
return buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
if (opName == "Elu")
return buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Equal")
return buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Erf")
return buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Exp")
return buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Expand")
return buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "EyeLike")
return buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Flatten")
return buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Floor")
return buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GRU")
return buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
if (opName == "Gather")
return buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "GatherElements")
return buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "GatherND")
return buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Gemm")
return buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "GlobalAveragePool")
return buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GlobalLpPool")
return buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GlobalMaxPool")
return buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Greater")
return buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "HardSigmoid")
return buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Hardmax")
return buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Identity")
return buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "If")
return buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
if (opName == "InstanceNormalization")
return buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "IsInf")
return buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "IsNaN")
return buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LRN")
return buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LSTM")
return buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
if (opName == "LeakyRelu")
return buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Less")
return buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Log")
return buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LogSoftmax")
return buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Loop")
return buildOperation<mlir::ONNXLoopOp>(node);
if (opName == "LpNormalization")
return buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LpPool")
return buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "MatMul")
return buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "MatMulInteger")
return buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "Max")
return buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "MaxPool")
return ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
if (opName == "MaxRoiPool")
return buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "MaxUnpool")
return buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Mean")
return buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "MeanVarianceNormalization")
return buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Min")
return buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "Mod")
return buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Mul")
return buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Multinomial")
return buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Neg")
return buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "NonMaxSuppression")
return buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
if (opName == "NonZero")
return buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Not")
return buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "OneHot")
return buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Or")
return buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "PRelu")
return buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Pad")
return ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Pow")
return buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "QLinearConv")
return buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
if (opName == "QLinearMatMul")
return buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
if (opName == "QuantizeLinear")
return buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "RNN")
return buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
if (opName == "RandomNormal")
return buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "RandomNormalLike")
return buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "RandomUniform")
return buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "RandomUniformLike")
return buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Range")
return buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Reciprocal")
return buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceL1")
return buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceL2")
return buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceLogSum")
return buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceLogSumExp")
return buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMax")
return buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMean")
return buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMin")
return buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceProd")
return buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceSum")
return buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceSumSquare")
return buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Relu")
return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Reshape")
return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Resize")
return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "ReverseSequence")
return buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "RoiAlign")
return buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Round")
return buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Scan")
return buildOperation<mlir::ONNXScanOp>(node);
if (opName == "Scatter")
return buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ScatterElements")
return buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ScatterND")
return buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Selu")
return buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "SequenceAt")
return buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "SequenceConstruct")
return buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "SequenceEmpty")
return buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "SequenceErase")
return buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "SequenceInsert")
return buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "SequenceLength")
return buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Shape")
return buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Shrink")
return buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sigmoid")
return buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sign")
return buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sin")
return buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sinh")
return buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Size")
return buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Slice")
return buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
if (opName == "Softmax")
return buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Softplus")
return buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Softsign")
return buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "SpaceToDepth")
return buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Split")
return buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
if (opName == "SplitToSequence")
return buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Sqrt")
return buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Squeeze")
return buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "StringNormalizer")
return buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sub")
return buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Sum")
return buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "Tan")
return buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Tanh")
return buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "TfIdfVectorizer")
return buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ThresholdedRelu")
return buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Tile")
return buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "TopK")
return buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
if (opName == "Transpose")
return buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Unique")
return buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
if (opName == "Unsqueeze")
return buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Upsample")
return buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Where")
return buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Xor")
return buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);

View File

@ -8,404 +8,11 @@
// Krnl IR and standard operations.
//
//===----------------------------------------------------------------------===//
#include <map>
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "src/dialect/krnl/krnl_helper.hpp"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
//===----------------------------------------------------------------------===//
// FrontendToAffine RewritePatterns
//===----------------------------------------------------------------------===//
/// Check is all dimensions are known at compile time.
static bool hasAllConstantDimensions(MemRefType type) {
auto memRefShape = type.getShape();
for (int i = 0; i < memRefShape.size(); ++i)
if (memRefShape[i] < 0)
return false;
return true;
}
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
static MemRefType convertToMemRefType(Type type) {
MemRefType memRefType;
auto tensorType = type.dyn_cast<TensorType>();
if (tensorType) {
assert(tensorType.hasRank() && "expected only ranked shapes");
memRefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
} else {
memRefType = type.dyn_cast<MemRefType>();
}
return memRefType;
}
/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter,
bool insertDealloc,
ArrayRef<Value> operands = {}) {
// Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc;
if (!operands.empty()) {
auto memRefShape = type.getShape();
auto rank = memRefShape.size();
std::map<int, Value> fromOperands;
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
int memRefDimIdx = rank - 1 - reversedIdx;
if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
Value maxDim = nullptr;
for (int i = 0; i < operands.size(); i++) {
auto operandShape =
operands[i].getType().cast<MemRefType>().getShape();
int operandDimIdx = operandShape.size() - 1 - reversedIdx;
if (operandDimIdx < 0)
continue;
// In case of operations with broadcasting, the dimension of the
// alloc result is the maximum size along each dimension of the
// operands.
auto operandDim =
rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
if (maxDim) {
auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
operandDim, maxDim);
maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
maxDim);
} else {
maxDim = operandDim;
}
}
fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
}
}
SmallVector<Value, 4> allocOperands;
for (int i = 0; i < rank; ++i)
if (memRefShape[i] < 0)
allocOperands.push_back(fromOperands[i]);
alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
} else {
alloc = rewriter.create<AllocOp>(loc, type);
}
// Make sure to allocate at the beginning of the block if
// all dimensions are known.
auto *parentBlock = alloc.getOperation()->getBlock();
if (hasAllConstantDimensions(type))
alloc.getOperation()->moveBefore(&parentBlock->front());
if (insertDealloc) {
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
return alloc;
}
// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
static bool checkInsertDealloc(Operation *currentOp) {
auto parentBlock = currentOp->getBlock();
bool insertDealloc = true;
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
assert(currentOp->getNumResults() < 2 &&
"No more than one result supported (for now).");
// If there is at least one result to investigate.
if (currentOp->getNumResults() > 0) {
auto result = currentOp->getResult(0);
for (const auto &operand : op.getOperands())
if (operand == result)
insertDealloc = false;
}
});
return insertDealloc;
}
// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
std::map<int64_t, int64_t> OutInDimMap;
int64_t rank = inputTy.getRank();
// Mark reduction axes.
std::vector<bool> isReductionAxis;
for (decltype(rank) i = 0; i < rank; ++i) {
if (std::find(axes.begin(), axes.end(), i) != axes.end())
isReductionAxis.push_back(true);
else
isReductionAxis.push_back(false);
}
for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
// If it is a reduction axis, there is no relationship among dimensions.
if (isReductionAxis[inIndex]) {
if (keepdims)
outIndex++;
} else {
OutInDimMap.insert(std::make_pair(outIndex, inIndex));
outIndex++;
}
}
return OutInDimMap;
}
// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimenions are supported.
static void addDimensionToPack(ConversionPatternRewriter &rewriter,
Location loc, KrnlIterateOperandPack &pack,
Value operand, int index) {
auto shape = operand.getType().cast<MemRefType>().getShape();
if (shape[index] < 0) {
pack.pushConstantBound(0);
pack.pushOperandBound(
rewriter.create<DimOp>(loc, operand, index).getResult());
} else {
pack.pushConstantBound(0);
pack.pushConstantBound(shape[index]);
}
}
// Function that defines the KRNL dialect loops and their respective
// optimized version.
static KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops, int64_t numLoops) {
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
loops.reserve(numLoops);
for (auto result : loopsOp.getResults())
loops.push_back(result);
// Define optimized version of the loops.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
optimizedLoops.reserve(numLoops);
for (auto result : optimizedLoopsOp.getResults())
optimizedLoops.push_back(result);
return optimizedLoopsOp;
}
// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
static Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops,
int64_t numLoops) {
KrnlOptimizeLoopsOp optimizedLoopsOp =
emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
return &optimizedLoopsOp.region().front();
}
// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
static void emitKrnlLoopsAndIterationForOperand(
ConversionPatternRewriter &rewriter, Location loc, Value operand,
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
KrnlIterateOp &iterateOp) {
// Operand shape.
auto shape = operand.getType().cast<MemRefType>().getShape();
// Number of loops.
int64_t rank = shape.size();
// Define loops and optimized loops.
std::vector<Value> optimizedLoops;
optimizedLoopsOp =
emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Iterate over the loop nest.
for (int i = 0; i < rank; ++i)
addDimensionToPack(rewriter, loc, pack, operand, i);
iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
}
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
auto elementType = memRefType.getElementType();
unsigned sizeInBits;
if (elementType.isIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
} else {
auto vectorType = elementType.cast<VectorType>();
sizeInBits =
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
}
return llvm::divideCeil(sizeInBits, 8);
}
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
MemRefType memRefType, ArrayRef<Value> operands) {
auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size();
// For unknown dimensions, we need to get dimension values at runtime in
// order to do broadcasting.
std::map<int, std::map<int, Value>> DimInfo;
// For each result dimension, compute the number of sharing operands.
// Sharing operands are operands sharing the same index (counting from the
// rightmost to the leftmost) for a given dimension.
std::map<int, int> sharedDimCount;
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
int dimIdx = rank - 1 - reversedIdx;
sharedDimCount[dimIdx] = 0;
for (int i = 0; i < operands.size(); ++i) {
auto shape = operands[i].getType().cast<MemRefType>().getShape();
if (reversedIdx <= shape.size() - 1)
sharedDimCount[dimIdx]++;
}
}
// An unknown dimension can have a value of 1 or N (N > 1).
// If its value is 1, it is broadcasted dimension.
// Otherwise, non-broadcasted dimension.
// We only care about unknown dimensions whose number of sharing operands is
// more than one, since they are potentially broadcasted dimensions.
for (int i = 0; i < operands.size(); ++i) {
std::map<int, Value> broadcastedDims;
auto shape = operands[i].getType().cast<MemRefType>().getShape();
int size = shape.size();
for (int j = 0; j < shape.size(); ++j) {
if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
auto one = rewriter.create<ConstantIndexOp>(loc, 1);
auto isBroadcasted =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
broadcastedDims.insert(std::make_pair(j, isBroadcasted));
}
}
DimInfo.insert(std::make_pair(i, broadcastedDims));
}
return DimInfo;
}
// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value> broadcastedDims) {
// `operand` must has a ranked type. This should have been checked by the
// shape inference pass.
auto operandShape = operand.getType().cast<MemRefType>().getShape();
auto rank = operandShape.size();
auto loopCount = loopIVs.size();
std::vector<Value> newLoopIVs;
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
auto dimIdx = rank - 1 - reversedIdx;
auto loopIdx = loopCount - 1 - reversedIdx;
if (operandShape[dimIdx] == 1) {
// Broadcasted dimension
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
newLoopIVs.insert(newLoopIVs.begin(), zero);
} else if ((operandShape[dimIdx] == -1) &&
(broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
// Unknown dimension, it can have a value of 1 or N (N > 1).
// If its value is 1, it is broadcasted dimension.
// Otherwise, non-broadcasted dimension.
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
loopIVs[loopIdx]);
newLoopIVs.insert(newLoopIVs.begin(), idx);
} else {
// Non-broadcasted dimension
newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
}
}
return newLoopIVs;
}
namespace {
// This is to get a scalar operation of a given type for a specific operation.
template <typename Op>
struct ScalarOp {
using FOp = void;
using IOp = void;
};
template <typename FOp>
using ScalarFOp = typename ScalarOp<FOp>::FOp;
template <typename IOp>
using ScalarIOp = typename ScalarOp<IOp>::IOp;
// Get the identity element of a operation.
// Return NULL if the function does not have identity.
template <typename DataType, typename Op>
DataType getIdentityValue() {
return NULL;
}
//===----------------------------------------------------------------------===//
// This is used in the innermost loop of a KrnlIterateOp to insert computation
// composed of one or many scalar ops.
// Use template specialization for each of different ONNX operations.
//===----------------------------------------------------------------------===//
template <typename Op>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
auto loc = op->getLoc();
Type element_type = operands.front().getType();
if (element_type.isa<IntegerType>()) {
return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
mlir::None);
} else if (element_type.isa<FloatType>()) {
return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
mlir::None);
} else {
emitError(loc, "unsupported element type");
return nullptr;
}
}
// We divide the operator lowering into different categories.
// These categories are mostly similar to the operator categories in ONNX:
// https://github.com/onnx/onnx/tree/master/onnx/defs.
// Besides, it is better to put operators with the same computation pattern into
// the same category, e.g. element-wise operators will belong to the elementwise
// category.
// Math
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc"
// Tensor
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
// Neural network
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc"
//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===//
@ -427,39 +34,6 @@ public:
}
};
//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//
struct TensorTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
TensorTypeConverter() {
addConversion(convertType);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
if (auto type = convertToMemRefType(t)) {
results.push_back(type);
return success();
}
results.push_back(t);
return success();
}
/// Return true if the inputs and outputs of the given function type are
/// legal. [Taken from MLIR and adapted to only check the legality of the
/// inputs. Once unranked results can be handled gracefully this
/// override needs to be removed in favour of the original MLIR one.]
bool isSignatureLegal(FunctionType funcType) {
return llvm::all_of(funcType.getInputs(),
[this](Type type) { return isLegal(type); });
}
};
} // end anonymous namespace.
//===----------------------------------------------------------------------===//
// Frontend to Krnl Dialect lowering pass
//===----------------------------------------------------------------------===//

View File

@ -1,4 +1,4 @@
//===----- elementwise.inc - Elementwise Ops ------------------------------===//
//===----- elementwise.cpp - Elementwise Ops ------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
template <>
struct ScalarOp<ONNXAddOp> {
using FOp = AddFOp;

View File

@ -1,4 +1,4 @@
//===----- gemm.inc - Lowering Gemm Op ------------------------------------===//
//===----- gemm.cpp - Lowering Gemm Op ------------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
template <typename GemmOp>
struct ONNXGemmOpLowering : public ConversionPattern {
ONNXGemmOpLowering(MLIRContext *ctx)
@ -17,19 +21,21 @@ struct ONNXGemmOpLowering : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
auto has_bias = (operands.size() == 3);
bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
Value A, B, C;
A = operands[0];
B = operands[1];
if (has_bias)
if (hasBias)
C = operands[2];
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
auto alphaAttr =
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
auto betaAttr = FloatAttr::get(memRefType.getElementType(),
auto betaAttr =
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -68,8 +74,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// Define loops.
std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
optimizedLoops, numLoops);
Block *optimizationBlock =
defineLoops(rewriter, loc, originalLoops, optimizedLoops, numLoops);
// We have two Krnl loops:
// - Outer loop iterates over the output matrix dimensions, and
@ -83,8 +89,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]);
}
KrnlIterateOperandPack outerPack(rewriter, outerLoops,
optimizedOuterLoops);
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
// Induction variables for the outer loops
for (int i = 0; i < 2; ++i)
addDimensionToPack(rewriter, loc, outerPack, alloc, i);
@ -107,8 +112,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
reductionPack.pushConstantBound(0);
if (ATy.getShape()[K_A_Idx] != -1)
reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
else
if (BTy.getShape()[K_B_Idx] != -1)
else if (BTy.getShape()[K_B_Idx] != -1)
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
else
reductionPack.pushOperandBound(
@ -119,7 +123,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// GemmOp supports unidirectional broadcasting from C to A*B.
// Hence, it must be enough to get broadcasting information for C only.
std::map<int, Value> broadcastedDimInfo;
if (has_bias) {
if (hasBias) {
auto shape = C.getType().cast<MemRefType>().getShape();
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] < 0) {
@ -162,7 +166,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
if (has_bias) {
if (hasBias) {
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
@ -210,8 +214,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
}
};
void populateLoweringONNXGemmOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx) {
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
}

View File

@ -1,4 +1,4 @@
//===----- matmul.inc - Lowering Matmul Op --------------------------------===//
//===----- matmul.cpp - Lowering Matmul Op --------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXMatMulOpLowering : public ConversionPattern {
ONNXMatMulOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,4 @@
//===----- reduction.inc - Lowering Reduction Ops -------------------------===//
//===----- reduction.cpp - Lowering Reduction Ops -------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
// Identity values
template <>
float getIdentityValue<float, ONNXReduceMaxOp>(){

View File

@ -1,4 +1,4 @@
//===----- softmax.inc - Softmax Op ---------------------------------------===//
//===----- softmax.cpp - Softmax Op ---------------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXSoftmaxOpLowering : public ConversionPattern {
ONNXSoftmaxOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,4 @@
//===----- conv.inc - Lowering Convolution Op -----------------------------===//
//===----- conv.cpp - Lowering Convolution Op -----------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -8,12 +8,15 @@
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXConvNoBiasOpLowering : public ConversionPattern {
ONNXConvNoBiasOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
@ -25,12 +28,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
{operands[0]});
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {operands[0]});
auto resultShape = memRefType.getShape();
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();
auto &inputOperand = operands[0];
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
auto &kernelOperand = operands[1];
auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();
// R = ConvNoBias(D, K)
//
@ -91,123 +96,82 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
loc, FloatAttr::get(memRefType.getElementType(), 0));
Value subchannels;
if (kernelShape[1] < 0) {
subchannels =
rewriter.create<DimOp>(loc, operands[1], 1).getResult();
subchannels = rewriter.create<DimOp>(loc, kernelOperand, 1).getResult();
} else {
subchannels = rewriter.create<ConstantIndexOp>(
loc, kernelShape[1]);
subchannels = rewriter.create<ConstantIndexOp>(loc, kernelShape[1]);
}
// 1. Define outer loops and emit empty optimization block:
int64_t nOuterLoops = (group > 1) ? 3 : 2;
std::vector<Value> outerLoops;
std::vector<Value> optimizedOuterLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
optimizedOuterLoops, nOuterLoops);
// Prepare iteration arguments over outer loop nest.
KrnlIterateOperandPack pack(
rewriter, outerLoops, optimizedOuterLoops);
BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops);
outerLoops.createDefineAndOptimizeOp();
// for n = 0 .. N:
pack.pushConstantBound(0);
if (inputShape[0] < 0)
pack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], 0).getResult());
else
pack.pushConstantBound(inputShape[0]);
int nIndex = outerLoops.pushBounds(0, inputOperand, 0);
// for g = 0 .. N:
if (group > 1) {
pack.pushConstantBound(0);
pack.pushConstantBound(group);
}
int gIndex = -1;
if (group > 1)
gIndex = outerLoops.pushBounds(0, group);
// for m = 0 .. kernelsPerGroup:
pack.pushConstantBound(0);
pack.pushConstantBound(kernelsPerGroup);
// Outer loop iteration.
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block &outerIterationBlock = iterateOp.bodyRegion().front();
// Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
rewriter.setInsertionPointToStart(&outerIterationBlock);
int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
// Outer loop iteration
outerLoops.createIterateOp();
rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
{
// 2. Emit the body of the outer loop nest.
// 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
// If group is not set then the value of the kernel ID is
// identical to that of the loop over kernels.
Value kernel = outerIterationBlock.getArguments()[1];
Value kernel = outerLoops.getInductionVar(mIndex);
if (group > 1) {
// Middle loop is over groups and third loop is over the
// kernel identifiers in the current group.
auto kernelsOffset = rewriter.create<MulIOp>(loc,
outerIterationBlock.getArguments()[1],
kernelsPerGroupValue);
kernel = rewriter.create<AddIOp>(loc, kernelsOffset,
outerIterationBlock.getArguments()[2]);
auto kernelsOffset = rewriter.create<MulIOp>(
loc, outerLoops.getInductionVar(gIndex), kernelsPerGroupValue);
kernel = rewriter.create<AddIOp>(
loc, kernelsOffset, outerLoops.getInductionVar(mIndex));
}
// 2.2 Define spatial loops
int64_t nSpatialLoops = resultShape.size() - 2;
std::vector<Value> spatialLoops;
std::vector<Value> optimizedSpatialLoops;
Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
optimizedSpatialLoops, nSpatialLoops);
// 2.3 Prepare iteration arguments for spatial loop nest.
KrnlIterateOperandPack spatialPack(
rewriter, spatialLoops, optimizedSpatialLoops);
BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops);
spatialLoops.createDefineAndOptimizeOp();
for (int i = 2; i < resultShape.size(); ++i)
addDimensionToPack(rewriter, loc, spatialPack, alloc, i);
spatialLoops.pushBounds(0, alloc, i);
// 2.4 Emit loop nest over output spatial dimensions.
// for rX = 0 .. RX
auto spatialIterateOp =
rewriter.create<KrnlIterateOp>(loc, spatialPack);
Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
// 2.5 Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
rewriter.setInsertionPointToStart(&spatialIterationBlock);
spatialLoops.createIterateOp();
rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock());
{
// 3. Emit the body of the spatial loop nest.
// 3.1 Emit: R[n][kernel][r1][r2] = 0;
SmallVector<Value, 4> resultIndices;
// n
resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
resultIndices.emplace_back(outerLoops.getInductionVar(nIndex));
// kernel
resultIndices.emplace_back(kernel);
// rX
for (auto arg : spatialIterationBlock.getArguments())
for (auto arg : spatialLoops.getIterateBlock()->getArguments())
resultIndices.emplace_back(arg);
// Store initializer value into output location.
rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);
// 3.2 Define inner loops.
int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
std::vector<Value> innerLoops;
std::vector<Value> optimizedInnerLoops;
Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
optimizedInnerLoops, nInnerLoops);
// 3.3 Prepare iteration arguments for inner loop nest.
KrnlIterateOperandPack innerPack(
rewriter, innerLoops, optimizedInnerLoops);
BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops);
innerLoops.createDefineAndOptimizeOp();
// for c = 0 .. C/group
innerPack.pushConstantBound(0);
innerPack.pushConstantBound(kernelShape[1]);
int cIndex = innerLoops.pushBounds(0, kernelShape[1]);
// for Kx = 0 .. KX
for (int i = 2; i < kernelShape.size(); ++i)
addDimensionToPack(rewriter, loc, innerPack, operands[1], i);
innerLoops.pushBounds(0, kernelOperand, i);
// 3.4 Emit inner loop nest.
auto innerIterateOp =
rewriter.create<KrnlIterateOp>(loc, innerPack);
Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
// 3.5 Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optInnerLoopBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
rewriter.setInsertionPointToStart(&innerIterationBlock);
innerLoops.createIterateOp();
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
{
// 4. Emit inner loop body
// R[n][kernel][r1][r2] =
@ -217,13 +181,13 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// 4.1 Prepare indices for accesing the data tensor.
SmallVector<Value, 4> dataIndices;
// n
dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
dataIndices.emplace_back(outerLoops.getInductionVar(nIndex));
// g * (C / group) + c
Value channelDepth = innerIterationBlock.getArguments()[0];
Value channelDepth = innerLoops.getInductionVar(cIndex);
if (group > 1)
channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
rewriter.create<MulIOp>(loc, subchannels,
outerIterationBlock.getArguments()[1]));
rewriter.create<MulIOp>(
loc, subchannels, outerLoops.getInductionVar(gIndex)));
dataIndices.emplace_back(channelDepth);
// sX * rX + kX
auto stridesAttribute = convOp.stridesAttr();
@ -233,15 +197,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
for (auto stride : stridesAttribute.getValue())
strides.emplace_back(stride.cast<IntegerAttr>().getInt());
for (int i = 0; i < kernelShape.size() - 2; ++i) {
Value spatialIndex = spatialIterationBlock.getArguments()[i];
Value spatialIndex = spatialLoops.getInductionVar(i);
// If strides are present then emit the correct access index.
if (stridesAttribute && strides[i] > 1)
spatialIndex = rewriter.create<MulIOp>(loc,
rewriter.create<ConstantIndexOp>(loc, strides[i]),
spatialIterationBlock.getArguments()[i]);
dataIndices.emplace_back(
rewriter.create<AddIOp>(loc, spatialIndex,
innerIterationBlock.getArguments()[i+1]));
spatialLoops.getInductionVar(i));
dataIndices.emplace_back(rewriter.create<AddIOp>(
loc, spatialIndex, innerLoops.getInductionVar(i + 1)));
}
// 4.2 Prepare indices for accessing the kernel tensor.
@ -249,17 +212,16 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// kernel
kernelIndices.emplace_back(kernel);
// c
kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
kernelIndices.emplace_back(innerLoops.getInductionVar(cIndex));
// kX
for (int i = 0; i < kernelShape.size() - 2; ++i)
kernelIndices.emplace_back(
innerIterationBlock.getArguments()[i+1]);
kernelIndices.emplace_back(innerLoops.getInductionVar(i + 1));
// 4.3 Compute convolution.
auto loadData =
rewriter.create<LoadOp>(loc, operands[0], dataIndices);
rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
auto loadKernel =
rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
rewriter.create<LoadOp>(loc, kernelOperand, kernelIndices);
auto loadPartialSum =
rewriter.create<LoadOp>(loc, alloc, resultIndices);
Value result = rewriter.create<AddFOp>(loc, loadPartialSum,

View File

@ -1,4 +1,4 @@
//===----- normalization.inc - Lowering Normalization Ops -----------------===//
//===----- normalization.cpp - Lowering Normalization Ops -----------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
: ConversionPattern(

View File

@ -0,0 +1,324 @@
//====-- onnx_to_krnl_common.cpp - ONNX dialects to Krnl lowering ---------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the KRNL dialect.
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
/// Check is all dimensions are known at compile time.
bool hasAllConstantDimensions(MemRefType type) {
auto memRefShape = type.getShape();
for (int i = 0; i < memRefShape.size(); ++i)
if (memRefShape[i] < 0)
return false;
return true;
}
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
MemRefType convertToMemRefType(Type type) {
MemRefType memRefType;
auto tensorType = type.dyn_cast<TensorType>();
if (tensorType) {
assert(tensorType.hasRank() && "expected only ranked shapes");
memRefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
} else {
memRefType = type.dyn_cast<MemRefType>();
}
return memRefType;
}
/// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter,
bool insertDealloc,
ArrayRef<Value> operands) {
// Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc;
if (!operands.empty()) {
auto memRefShape = type.getShape();
auto rank = memRefShape.size();
std::map<int, Value> fromOperands;
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
int memRefDimIdx = rank - 1 - reversedIdx;
if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
Value maxDim = nullptr;
for (int i = 0; i < operands.size(); i++) {
auto operandShape =
operands[i].getType().cast<MemRefType>().getShape();
int operandDimIdx = operandShape.size() - 1 - reversedIdx;
if (operandDimIdx < 0)
continue;
// In case of operations with broadcasting, the dimension of the
// alloc result is the maximum size along each dimension of the
// operands.
auto operandDim =
rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
if (maxDim) {
auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
operandDim, maxDim);
maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
maxDim);
} else {
maxDim = operandDim;
}
}
fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
}
}
SmallVector<Value, 4> allocOperands;
for (int i = 0; i < rank; ++i)
if (memRefShape[i] < 0)
allocOperands.push_back(fromOperands[i]);
alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
} else {
alloc = rewriter.create<AllocOp>(loc, type);
}
// Make sure to allocate at the beginning of the block if
// all dimensions are known.
auto *parentBlock = alloc.getOperation()->getBlock();
if (hasAllConstantDimensions(type))
alloc.getOperation()->moveBefore(&parentBlock->front());
if (insertDealloc) {
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
return alloc;
}
// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
bool checkInsertDealloc(Operation *currentOp) {
auto parentBlock = currentOp->getBlock();
bool insertDealloc = true;
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
assert(currentOp->getNumResults() < 2 &&
"No more than one result supported (for now).");
// If there is at least one result to investigate.
if (currentOp->getNumResults() > 0) {
auto result = currentOp->getResult(0);
for (const auto &operand : op.getOperands())
if (operand == result)
insertDealloc = false;
}
});
return insertDealloc;
}
// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
std::map<int64_t, int64_t> OutInDimMap;
int64_t rank = inputTy.getRank();
// Mark reduction axes.
std::vector<bool> isReductionAxis;
for (decltype(rank) i = 0; i < rank; ++i) {
if (std::find(axes.begin(), axes.end(), i) != axes.end())
isReductionAxis.push_back(true);
else
isReductionAxis.push_back(false);
}
for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
// If it is a reduction axis, there is no relationship among dimensions.
if (isReductionAxis[inIndex]) {
if (keepdims)
outIndex++;
} else {
OutInDimMap.insert(std::make_pair(outIndex, inIndex));
outIndex++;
}
}
return OutInDimMap;
}
// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimenions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter,
Location loc, KrnlIterateOperandPack &pack,
Value operand, int index) {
auto shape = operand.getType().cast<MemRefType>().getShape();
if (shape[index] < 0) {
pack.pushConstantBound(0);
pack.pushOperandBound(
rewriter.create<DimOp>(loc, operand, index).getResult());
} else {
pack.pushConstantBound(0);
pack.pushConstantBound(shape[index]);
}
}
// Function that defines the KRNL dialect loops and their respective
// optimized version.
KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops, int64_t numLoops) {
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
loops.reserve(numLoops);
for (auto result : loopsOp.getResults())
loops.push_back(result);
// Define optimized version of the loops.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
optimizedLoops.reserve(numLoops);
for (auto result : optimizedLoopsOp.getResults())
optimizedLoops.push_back(result);
return optimizedLoopsOp;
}
// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops,
int64_t numLoops) {
KrnlOptimizeLoopsOp optimizedLoopsOp =
emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
return &optimizedLoopsOp.region().front();
}
// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand(
ConversionPatternRewriter &rewriter, Location loc, Value operand,
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
KrnlIterateOp &iterateOp) {
// Operand shape.
auto shape = operand.getType().cast<MemRefType>().getShape();
// Number of loops.
int64_t rank = shape.size();
// Define loops and optimized loops.
std::vector<Value> optimizedLoops;
optimizedLoopsOp =
emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Iterate over the loop nest.
for (int i = 0; i < rank; ++i)
addDimensionToPack(rewriter, loc, pack, operand, i);
iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
}
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
auto elementType = memRefType.getElementType();
unsigned sizeInBits;
if (elementType.isIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
} else {
auto vectorType = elementType.cast<VectorType>();
sizeInBits =
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
}
return llvm::divideCeil(sizeInBits, 8);
}
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
MemRefType memRefType, ArrayRef<Value> operands) {
auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size();
// For unknown dimensions, we need to get dimension values at runtime in
// order to do broadcasting.
std::map<int, std::map<int, Value>> DimInfo;
// For each result dimension, compute the number of sharing operands.
// Sharing operands are operands sharing the same index (counting from the
// rightmost to the leftmost) for a given dimension.
std::map<int, int> sharedDimCount;
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
int dimIdx = rank - 1 - reversedIdx;
sharedDimCount[dimIdx] = 0;
for (int i = 0; i < operands.size(); ++i) {
auto shape = operands[i].getType().cast<MemRefType>().getShape();
if (reversedIdx <= shape.size() - 1)
sharedDimCount[dimIdx]++;
}
}
// An unknown dimension can have a value of 1 or N (N > 1).
// If its value is 1, it is broadcasted dimension.
// Otherwise, non-broadcasted dimension.
// We only care about unknown dimensions whose number of sharing operands is
// more than one, since they are potentially broadcasted dimensions.
for (int i = 0; i < operands.size(); ++i) {
std::map<int, Value> broadcastedDims;
auto shape = operands[i].getType().cast<MemRefType>().getShape();
int size = shape.size();
for (int j = 0; j < shape.size(); ++j) {
if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
auto one = rewriter.create<ConstantIndexOp>(loc, 1);
auto isBroadcasted =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
broadcastedDims.insert(std::make_pair(j, isBroadcasted));
}
}
DimInfo.insert(std::make_pair(i, broadcastedDims));
}
return DimInfo;
}
// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value> broadcastedDims) {
// `operand` must has a ranked type. This should have been checked by the
// shape inference pass.
auto operandShape = operand.getType().cast<MemRefType>().getShape();
auto rank = operandShape.size();
auto loopCount = loopIVs.size();
std::vector<Value> newLoopIVs;
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
auto dimIdx = rank - 1 - reversedIdx;
auto loopIdx = loopCount - 1 - reversedIdx;
if (operandShape[dimIdx] == 1) {
// Broadcasted dimension
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
newLoopIVs.insert(newLoopIVs.begin(), zero);
} else if ((operandShape[dimIdx] == -1) &&
(broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
// Unknown dimension, it can have a value of 1 or N (N > 1).
// If its value is 1, it is broadcasted dimension.
// Otherwise, non-broadcasted dimension.
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
loopIVs[loopIdx]);
newLoopIVs.insert(newLoopIVs.begin(), idx);
} else {
// Non-broadcasted dimension
newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
}
}
return newLoopIVs;
}

View File

@ -0,0 +1,217 @@
//====-- onnx_to_krnl_common.hpp - ONNX dialects to Krnl lowering ---------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the KRNL dialect.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <map>
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "mlir/IR/PatternMatch.h"
#include "src/dialect/krnl/krnl_helper.hpp"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Common functions used when lowering the ONNX frontend dialect to KRNL.
//===----------------------------------------------------------------------===//
/// Check is all dimensions are known at compile time.
bool hasAllConstantDimensions(MemRefType type);
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
MemRefType convertToMemRefType(Type type);
/// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter,
bool insertDealloc,
ArrayRef<Value> operands = {});
// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
bool checkInsertDealloc(Operation *currentOp);
// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);
// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimenions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter,
Location loc, KrnlIterateOperandPack &pack,
Value operand, int index);
// Function that defines the KRNL dialect loops and their respective
// optimized version.
KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops, int64_t numLoops);
// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops,
int64_t numLoops);
// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand(
ConversionPatternRewriter &rewriter, Location loc, Value operand,
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
KrnlIterateOp &iterateOp);
unsigned getMemRefEltSizeInBytes(MemRefType memRefType);
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
MemRefType memRefType, ArrayRef<Value> operands);
// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value> broadcastedDims);
//===----------------------------------------------------------------------===//
// This is to get a scalar operation of a given type for a specific operation.
//===----------------------------------------------------------------------===//
template <typename Op>
struct ScalarOp {
using FOp = void;
using IOp = void;
};
template <typename FOp>
using ScalarFOp = typename ScalarOp<FOp>::FOp;
template <typename IOp>
using ScalarIOp = typename ScalarOp<IOp>::IOp;
// Get the identity element of a operation.
// Return NULL if the function does not have identity.
template <typename DataType, typename Op>
DataType getIdentityValue() {
return NULL;
}
//===----------------------------------------------------------------------===//
// This is used in the innermost loop of a KrnlIterateOp to insert computation
// composed of one or many scalar ops.
// Use template specialization for each of different ONNX operations.
//===----------------------------------------------------------------------===//
template <typename Op>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
auto loc = op->getLoc();
Type element_type = operands.front().getType();
if (element_type.isa<IntegerType>()) {
return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
mlir::None);
} else if (element_type.isa<FloatType>()) {
return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
mlir::None);
} else {
emitError(loc, "unsupported element type");
return nullptr;
}
}
//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//
struct TensorTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
TensorTypeConverter() {
addConversion(convertType);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
if (auto type = convertToMemRefType(t)) {
results.push_back(type);
return success();
}
results.push_back(t);
return success();
}
/// Return true if the inputs and outputs of the given function type are
/// legal. [Taken from MLIR and adapted to only check the legality of the
/// inputs. Once unranked results can be handled gracefully this
/// override needs to be removed in favour of the original MLIR one.]
bool isSignatureLegal(FunctionType funcType) {
return llvm::all_of(funcType.getInputs(),
[this](Type type) { return isLegal(type); });
}
};
//===----------------------------------------------------------------------===//
// Functions to add lowering patterns for frontend operations.
//===----------------------------------------------------------------------===//
// `math` directory methods:
void populateLoweringONNXElementwiseOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx);
void populateLoweringONNXMatMulOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXReductionOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXSoftmaxOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
// `nn` directory methods:
void populateLoweringONNXConvOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXNormalizationOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
// `tensor` directory methods:
void populateLoweringONNXUnsqueezeOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXTransposeOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXReshapeOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXIdentityOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);

View File

@ -1,4 +1,4 @@
//===----- identity.inc - Lowering Identity Op ----------------------------===//
//===----- identity.cpp - Lowering Identity Op ----------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXIdentityOpLowering : public ConversionPattern {
ONNXIdentityOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,4 @@
//===----- reshape.inc - Lowering Reshape Op ------------------------------===//
//===----- reshape.cpp - Lowering Reshape Op ------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXReshapeOpLowering : public ConversionPattern {
ONNXReshapeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,4 @@
//===----- transpose.inc - Lowering Transpose Op --------------------------===//
//===----- transpose.cpp - Lowering Transpose Op --------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXTransposeOpLowering : public ConversionPattern {
ONNXTransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,4 @@
//===----- unsqueeze.inc - Lowering Unsqueeze Op --------------------------===//
//===----- unsqueeze.cpp - Lowering Unsqueeze Op --------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXUnsqueezeOpLowering : public ConversionPattern {
ONNXUnsqueezeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,5 @@
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "src/dialect/krnl/krnl_ops.hpp"
@ -9,9 +10,8 @@ namespace onnf {
using namespace mlir;
ParseResult
KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
Value &operand) {
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
const Type &operandType, Value &operand) {
// If operand queue is empty, parse more operands and cache them.
if (_operandRefQueue.empty()) {
// Parse operand types:
@ -48,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
return success();
}
ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType,
Value &operand) {
ParseResult KrnlDialectOperandParser::ParseOperand(
const Type &operandType, Value &operand) {
if (ParseOptionalOperand(operandType, operand))
return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand.");
@ -123,6 +123,7 @@ void printBound(AffineMapAttr boundMap,
} // namespace onnf
namespace mlir {
void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
@ -130,11 +131,143 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
boundMaps.emplace_back(AffineMapAttr::get(map));
}
void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
void KrnlIterateOperandPack::pushOperandBound(Value operand) {
if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
AffineMap map = builder.getSymbolIdentityMap();
boundMaps.emplace_back(AffineMapAttr::get(map));
_operands.emplace_back(operand);
}
BuildKrnlLoop::BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, int loopNum)
: rewriter(rewriter), loc(loc), originalLoopNum(loopNum), pack(NULL),
pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
createdIterateOp(false) {
if (originalLoopNum <= 0)
emitError(loc, "Expected positive number of original loops.");
}
BuildKrnlLoop::BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand)
: BuildKrnlLoop(rewriter, loc,
memRefOperand.getType().cast<MemRefType>().getShape().size()) {}
BuildKrnlLoop::~BuildKrnlLoop() {
if (pack)
free(pack);
}
void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
// Insert define loop operation.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
originalLoops.reserve(originalLoopNum);
for (auto result : loopsOp.getResults())
originalLoops.push_back(result);
createdDefineOp = true;
// Insert optimize loop operation.
auto optimizedLoopsOp =
rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
optLoops.reserve(originalLoopNum);
// Emit empty optimizations if flag is set.
if (withEmptyOptimization) {
for (auto result : optimizedLoopsOp.getResults())
optLoops.push_back(result);
optBlock = &optimizedLoopsOp.region().front();
auto ip = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToEnd(optBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.restoreInsertionPoint(ip);
}
createdOptimizeOp = true;
// prepare data structure to push bounds
pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
}
int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
pack->pushConstantBound(lowerBound);
pack->pushConstantBound(upperBound);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
pack->pushConstantBound(lowerBound);
pack->pushOperandBound(upperBound);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
pack->pushConstantBound(lowerBound);
// Process upperBound as a dimension of the MemRef. Non-constant dimensions
// are supported.
auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
if (shape[upperBoundMemRefIndex] < 0) {
if (upperBoundMustBeConstant)
emitError(loc, "Bound expected to be constant.");
pack->pushOperandBound(
rewriter
.create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
.getResult());
} else
pack->pushConstantBound(shape[upperBoundMemRefIndex]);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) {
pack->pushOperandBound(lowerBound);
pack->pushOperandBound(upperBound);
return pushCount++;
}
void BuildKrnlLoop::createIterateOp() {
// Loop definition operation is mandatory.
if (!createdDefineOp)
emitError(loc, "Must create define op before iterate op.");
// Loop optimization operation is mandatory (for now).
if (!createdOptimizeOp)
emitError(loc, "Must create optimize op before iterate op.");
// Check if all bounds have been defined.
if (pushCount != originalLoopNum)
emitError(loc, "Must push bounds for all original loops.");
// Emit iteration operation.
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
iterBlock = &iterateOp.bodyRegion().front();
createdIterateOp = true;
}
void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization) {
// Rank of the MemRef operand. We will emit a loop for each dimension.
int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
if (originalLoopNum != loopNum)
emitError(loc, "Mismatch in loop numbers from constructor and define.");
// Emit the definition and the optimization operations for the loop nest.
createDefineAndOptimizeOp(withEmptyOptimization);
// Push a lower-upper bound pair for each dimension of the MemRef operand.
// The lower bound in this case is always zero.
for (int i = 0; i < originalLoopNum; ++i)
pushBounds(0, memRefOperand, i);
// Emit the iteration operation over the current loop nest.
createIterateOp();
}
BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
// Check if loop iteration variable is within bounds.
if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
emitError(loc, "Original loop index is out of bounds.");
return iterBlock->getArguments()[originalLoopIndex];
}
} // namespace mlir

View File

@ -8,6 +8,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnf {
@ -17,21 +18,19 @@ class KrnlDialectOperandParser {
: _parser(parser), _builder(parser.getBuilder()){};
// Parse an optional operand.
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
mlir::Value &operand);
mlir::ParseResult ParseOptionalOperand(
const mlir::Type &operandType, mlir::Value &operand);
// Parse an optional operand and push it to an operand list.
mlir::ParseResult
ParseOptionalOperand(const mlir::Type &operandType,
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
// Parse a required operand.
mlir::ParseResult ParseOperand(const mlir::Type &operandType,
mlir::Value &operand);
mlir::ParseResult ParseOperand(
const mlir::Type &operandType, mlir::Value &operand);
// Parse a required operand and push it to an operand list.
mlir::ParseResult
ParseOperand(const mlir::Type &operandType,
mlir::ParseResult ParseOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
// Do we have more operands to parse?
@ -100,4 +99,121 @@ struct KrnlIterateOperandPack {
mlir::Builder &builder;
};
// Helper function to write kernel loops. This class will let us build a single
// define/optimize/iterate operation combo. We can then insert optimizations in
// the body of the optimization operation, and operations in the body of the
// iterate operation.
//
// The sequence is as follow:
//
// 1) Create an object giving the rewriter, location, and number of loop in
// the original (non optimized) loop.
//
// 2) Create define & optimize ops (currently paired). Optimizations can then
// be added to the inner block of the optimize operation. Make sure to set
// the insertion point to that block for optimizations to go in the right
// place.
//
// 3) Push the bounds for each of the original loops. Bounds are pushed in
// pairs (lower & upper bounds). There are a few methods to do it depending
// on the type of the bounds. When pushing bounds, the method returns a
// number that represent the index associated with that iteration (induction
// variable and bounds). That index can be used later to extract the
// induction variable for reference in computation and/or index calculations
// of mem refs.
//
// 4) Once all the bounds are pushed, create the iterate operation. Once this
// is done, we can add operations within the iterate blocks by setting the
// insertion point to it. Value of the induction variables can be retrieved
// using the proper index (determined when pushin the bounds).
class BuildKrnlLoop {
public:
// Create kernel loop builder for a loop nest of depth loopNum.
BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
// Create kernel loop builder for a loop nest of depth equal to the
// dimensionality of the operand. An operand of MemRef type is requied.
BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
~BuildKrnlLoop();
// Create define and optimize loop with loopNum original loops. If
// withEmptyOptimization is true, the optimization is simply the identity
// function (no optimizations).
void createDefineAndOptimizeOp(bool withEmptyOptimization = true);
// Push bounds (lower and upper) for each of the loops (order matters).
// The function returns the order number associated with the loop iteration.
// This index is used by the getInductionVar call. Non-constant operands
// must be of MemRef type.
int pushBounds(int64_t lowerBound, int64_t upperBound);
int pushBounds(int64_t lowerBound, Value upperBound);
int pushBounds(Value lowerBound, Value upperBound);
int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);
// Create the KrnlIterateOp assiciated with this loop nest. The loops
// iteration will be created if the definition and the optimization
// operations associated with this loop nest have been emitted already.
void createIterateOp();
// Create the loop nest definition, optimization and iteration operations
// for a given operand of MemRef type. The loop nest has a depth equal to the
// rank of the MemRef operand. The lower bound of each loop is zero. The
// upper bound of each loop is given by the corresponding dimension of the
// MemRef operand.
void createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization = true);
// Get the (original loop) induction variable associated with the given
// index. Use the index returned when pushing the bounds.
BlockArgument &getInductionVar(int originalLoopIndex);
// Get a reference to the code region of the optimization operation.
// This allows us to set the insertion point to the inner block of the
// loop nest optimization operation.
Block *getOptimizationBlock() { return optBlock; }
// Get a reference to the code region of the iteration operation.
// This allows us to set the insertion point to the inner block of the
// loop nest iteration operation.
Block *getIterateBlock() { return iterBlock; }
// Get original loop nest.
std::vector<Value> &getOriginalLoops() { return originalLoops; }
// Get optimized loop nest.
std::vector<Value> &getOptimizedLoops() { return optLoops; }
private:
// Required for emitting operations.
ConversionPatternRewriter &rewriter;
Location loc;
int originalLoopNum;
// List of original, un-optimized loops.
std::vector<Value> originalLoops;
// List of optimized loops.
std::vector<Value> optLoops;
// List of lower-upper bound pairs needed by the KrnlIterateOp.
KrnlIterateOperandPack *pack;
// Number of lower-upper bound pairs pushed.
int pushCount;
// Flags that keep track of emitted operations.
bool createdDefineOp;
bool createdOptimizeOp;
bool createdIterateOp;
// Saved insertion point in the code region of the KrnlOptimizeLoopsOp.
Block *optBlock;
// Saved insertion point in the code region of the KrnlIterateOp.
Block *iterBlock;
};
} // namespace mlir

View File

@ -90,25 +90,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
// or outputs. This decision affects only ONNX operations with optional
// arguments not ONNX operations with variadic operands.
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation without bias.";
let description = [{
The "onnx.Gemm" generic matrix multiplication without bias.
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
DefaultValuedAttr<F32Attr, "1.0">:$alpha,
DefaultValuedAttr<F32Attr, "1.0">:$beta,
DefaultValuedAttr<I64Attr, "0">:$transA,
DefaultValuedAttr<I64Attr, "0">:$transB);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1;

View File

@ -24,12 +24,29 @@
using namespace mlir;
using namespace mlir::OpTrait::util;
//===----------------------------------------------------------------------===//
// ONNX Helper functions
//===----------------------------------------------------------------------===//
static size_t ArrayAttrSize(ArrayAttr a) { return a.size(); }
static size_t ArrayAttrSize(Optional<ArrayAttr> a) {
return a.getValue().size();
}
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
}
static int64_t ArrayAttrIntVal(Optional<ArrayAttr> a, int i) {
return (a.getValue().getValue()[i]).cast<IntegerAttr>().getInt();
}
//===----------------------------------------------------------------------===//
// Get reduction type
//===----------------------------------------------------------------------===//
RankedTensorType getReductionOutputType(RankedTensorType operandTy,
Optional<ArrayAttr> axesAttrs,
APInt keepdims) {
RankedTensorType getReductionOutputType(
RankedTensorType operandTy, Optional<ArrayAttr> axesAttrs, APInt keepdims) {
int64_t rank = operandTy.getRank();
SmallVector<int64_t, 4> axes;
@ -87,8 +104,8 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
}
void ONNXEntryPointOp::build(mlir::Builder *builder,
mlir::OperationState &state, mlir::FuncOp function,
int numInputs, int numOutputs) {
mlir::OperationState &state, mlir::FuncOp function, int numInputs,
int numOutputs) {
state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
builder->getSymbolRefAttr(function));
state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
@ -98,8 +115,7 @@ void ONNXEntryPointOp::build(mlir::Builder *builder,
}
ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
mlir::FuncOp &func, int numInputs,
int numOutputs) {
mlir::FuncOp &func, int numInputs, int numOutputs) {
mlir::OperationState state(location, "onnx.EntryPoint");
Builder builder(location->getContext());
mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
@ -120,25 +136,19 @@ void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
// Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface.
void ONNXTanhOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Sinh
/// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface.
void ONNXSinhOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Cosh
/// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface.
void ONNXCoshOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Cos
@ -178,9 +188,7 @@ void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
// Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
void ONNXReluOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// LeakyRelu
@ -194,9 +202,7 @@ void ONNXLeakyReluOp::inferShapes() {
// Selu
/// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface.
void ONNXSeluOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Reciprocal
@ -234,17 +240,13 @@ void ONNXSoftsignOp::inferShapes() {
// Sqrt
/// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface.
void ONNXSqrtOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Sign
/// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface.
void ONNXSignOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Add
@ -404,12 +406,12 @@ void ONNXIdentityOp::inferShapes() {
void ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
auto lhsShape = lhsTy.getShape();
@ -417,15 +419,14 @@ void ONNXMatMulOp::inferShapes() {
if (lhsShape.size() < 1 && rhsShape.size() < 1) {
// Multiplication by scalars is not allowed.
emitError("Multiplication by scalar arguments not allowed.");
emitError("Multiplication by scalar arguments not allowed");
} else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
// Special case when both arrays are 1-dimensional and according to
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
// need to be removed after the multiplication but cannot be removed if all
// sizes are 1.
if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices");
dims.emplace_back(1);
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
// If the first argument is 1-D, it is promoted to a matrix by prepending a
@ -440,7 +441,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size();
if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[0] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]);
@ -458,7 +459,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
dims.emplace_back(lhsShape[i]);
@ -472,7 +473,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
dims.emplace_back(lhsShape[i]);
@ -486,7 +487,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size();
if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]);
@ -502,7 +503,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
// Check and perform broadcasting for the shapes.
SmallVector<int64_t, 2> lhsBcastShape;
@ -512,7 +513,7 @@ void ONNXMatMulOp::inferShapes() {
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
rhsBcastShape.emplace_back(rhsShape[i]);
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
emitError("Broadcasted dimensions are incompatible.");
emitError("Broadcasted dimensions are incompatible");
dims.emplace_back(lhsShape[lhsRank - 2]);
dims.emplace_back(rhsShape[rhsRank - 1]);
@ -527,7 +528,7 @@ void ONNXMatMulOp::inferShapes() {
// Check legality of matrix multiplication.
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
if (rhsShape.size() > 1)
dims.emplace_back(rhsShape[1]);
@ -541,14 +542,14 @@ void ONNXMatMulOp::inferShapes() {
// Gemm
void ONNXGemmOp::inferShapes() {
bool hasBias = !C().getType().isa<NoneType>();
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>())
if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() ||
(hasBias && !C().getType().isa<RankedTensorType>()))
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
@ -557,44 +558,21 @@ void ONNXGemmOp::inferShapes() {
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
emitError("Tensor shapes mismatched.");
emitError("Tensor shapes mismatched");
}
if (hasBias) {
// Check whether bias is unidirectional broadcasting or not.
auto biasTy = C().getType().cast<RankedTensorType>();
auto shape = biasTy.getShape();
int rank = shape.size();
if ((rank > 2) ||
(rank >= 1 && shape[rank - 1] != -1 && N != -1 && N != shape[rank - 1] &&
shape[rank - 1] != 1) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] &&
shape[rank - 2] != 1)) {
emitError("Bias shape mismatched.");
(rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
N != shape[rank - 1] && shape[rank - 1] != 1) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 &&
M != shape[rank - 2] && shape[rank - 2] != 1)) {
emitError("Bias shape mismatched");
}
SmallVector<int64_t, 2> dims;
dims.emplace_back(M);
dims.emplace_back(N);
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
// GemmNoBias
void ONNXGemmNoBiasOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
emitError("Tensor shapes mismatched.");
}
SmallVector<int64_t, 2> dims;
@ -606,50 +584,50 @@ void ONNXGemmNoBiasOp::inferShapes() {
/// BatchNormalizationTestMode
void ONNXBatchNormalizationTestModeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>() ||
!getOperand(3).getType().isa<RankedTensorType>() ||
!getOperand(4).getType().isa<RankedTensorType>())
if (!X().getType().isa<RankedTensorType>() ||
!scale().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() ||
!mean().getType().isa<RankedTensorType>() ||
!var().getType().isa<RankedTensorType>())
return;
auto input = getOperand(0).getType().cast<RankedTensorType>();
auto scale = getOperand(1).getType().cast<RankedTensorType>();
auto bias = getOperand(2).getType().cast<RankedTensorType>();
auto mean = getOperand(3).getType().cast<RankedTensorType>();
auto variance = getOperand(4).getType().cast<RankedTensorType>();
auto inputTensorTy = X().getType().cast<RankedTensorType>();
auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
auto biasTensorTy = B().getType().cast<RankedTensorType>();
auto meanTensorTy = mean().getType().cast<RankedTensorType>();
auto varianceTensorTy = var().getType().cast<RankedTensorType>();
// Check whether the shapes of scale, bias, mean and variance are valid.
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1.
// Shapes of scale, bias, mean and variance must be C.
int64_t c = -1;
if (input.getShape().size() == 1) {
if (inputTensorTy.getShape().size() == 1) {
c = 1;
} else if (input.getShape().size() > 2) {
c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1;
} else if (inputTensorTy.getShape().size() > 2) {
c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
} else {
emitError("Wrong rank for the input.");
emitError("Wrong rank for the input");
}
if (c != -1) {
auto s = scale.getShape();
auto b = bias.getShape();
auto m = mean.getShape();
auto v = variance.getShape();
auto s = scaleTensorTy.getShape();
auto b = biasTensorTy.getShape();
auto m = meanTensorTy.getShape();
auto v = varianceTensorTy.getShape();
if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
emitError("Wrong rank for the scale.");
emitError("Wrong rank for the scale");
if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
emitError("Wrong rank for the bias.");
emitError("Wrong rank for the bias");
if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
emitError("Wrong rank for the mean.");
emitError("Wrong rank for the mean");
if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
emitError("Wrong rank for the variance.");
emitError("Wrong rank for the variance");
}
// The output tensor of the same shape as the input.
getResult().setType(getOperand(0).getType());
getResult().setType(X().getType());
}
// TODO:
@ -662,21 +640,21 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
void ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified.
if (!getOperand(1).getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked.");
if (!shape().getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked");
auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>();
auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>();
auto inputTensorTy = data().getType().cast<RankedTensorType>();
auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
// Only rank 1 shape tensors are supported.
if (shapeTensorTy.getShape().size() != 1)
emitError("Shape tensor must have rank one.");
emitError("Shape tensor must have rank one");
int64_t outputRank = shapeTensorTy.getShape()[0];
// Shape tensor must have constant shape.
if (outputRank < 0)
emitError("Shape tensor must have constant shape.");
emitError("Shape tensor must have constant shape");
SmallVector<int64_t, 2> dims;
for (int i = 0; i < outputRank; ++i)
@ -692,12 +670,12 @@ void ONNXReshapeOp::inferShapes() {
void ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand().getType().isa<RankedTensorType>())
if (!data().getType().isa<RankedTensorType>())
return;
// Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose).
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
auto arrayTy = data().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
auto permutation = ONNXTransposeOp::permAttr();
if (permutation) {
@ -713,14 +691,13 @@ void ONNXTransposeOp::inferShapes() {
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
//===----------------------------------------------------------------------===//
// ReduceMax
void ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
emitError("Shape tensor not ranked");
return;
}
@ -734,7 +711,7 @@ void ONNXReduceMaxOp::inferShapes() {
void ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
emitError("Shape tensor not ranked");
return;
}
@ -748,7 +725,7 @@ void ONNXReduceMinOp::inferShapes() {
void ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
emitError("Shape tensor not ranked");
return;
}
@ -762,7 +739,7 @@ void ONNXReduceProdOp::inferShapes() {
void ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
emitError("Shape tensor not ranked");
return;
}
@ -781,30 +758,31 @@ void ONNXConvNoBiasOp::inferShapes() {
// W: (M x C/group x k1 x k2 x ... x kn)
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
if (!X().getType().isa<RankedTensorType>() ||
!W().getType().isa<RankedTensorType>())
return;
auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
auto dataTy = X().getType().cast<RankedTensorType>();
auto weightTy = W().getType().cast<RankedTensorType>();
auto dataShape = dataTy.getShape();
auto weightShape = weightTy.getShape();
// Lowest supported convolution is a one dimensional convolution.
if (dataShape.size() < 3)
emitError("Data input shape must be at least (NxCxD1).");
emitError("Data input shape must be at least (NxCxD1)");
// Check that shape of weight and data have same length.
if (dataShape.size() != weightShape.size())
emitError("Weight size not compatible with data size.");
emitError("Weight size not compatible with data size");
// Required attribute auto_pad defaults to NOTSET.
auto autoPad = auto_pad();
// Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
int64_t group =
ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (dataShape[1] != (weightShape[1] * group))
emitError("Channel dimension mismatch.");
emitError("Channel dimension mismatch");
// Note: the value of the group attribut only impacts the way the
// computation is carried out and not the actual output size.
@ -834,11 +812,10 @@ void ONNXConvNoBiasOp::inferShapes() {
// argument.
SmallVector<int64_t, 2> kernelDims;
if (auto kernelShape = kernel_shapeAttr()) {
if (kernelShape.getValue().size() != nDims)
emitError("kernel_shape length incompatible with spatial dimensions.");
if (ArrayAttrSize(kernelShape) != nDims)
emitError("kernel_shape length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i)
kernelDims.emplace_back(
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
kernelDims.emplace_back(ArrayAttrIntVal(kernelShape, i));
} else {
for (int i = 0; i < nDims; ++i)
kernelDims.emplace_back(weightShape[i + 2]);
@ -856,11 +833,11 @@ void ONNXConvNoBiasOp::inferShapes() {
// From a dimensionality perspective the kernel size becomes the dilated
// kernel size.
if (auto dilations = dilationsAttr()) {
if (dilations.getValue().size() != nDims)
emitError("dilations length incompatible with spatial dimensions.");
if (ArrayAttrSize(dilations) != nDims)
emitError("dilations length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i)
kernelDims[i] = (kernelDims[i] + 1) *
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() - 1;
kernelDims[i] =
(kernelDims[i] + 1) * ArrayAttrIntVal(dilations, i) - 1;
}
// Subtract kernel dimensions from input data dimensions.
@ -873,16 +850,14 @@ void ONNXConvNoBiasOp::inferShapes() {
// present then pads is considered to be all zeros (no padding).
if (auto pads = padsAttr()) {
// pads consists of two entries for each spatial axis.
if (pads.getValue().size() != 2 * nDims)
emitError("pads size is not twice the spatial size.");
if (ArrayAttrSize(pads) != 2 * nDims)
emitError("pads size is not twice the spatial size");
for (int i = 0; i < nDims; ++i) {
// Padding for beginning of axis.
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
outSpatialDims[i] += p;
outSpatialDims[i] += ArrayAttrIntVal(pads, i);
// Padding for end of axis.
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
outSpatialDims[i] += p;
outSpatialDims[i] += ArrayAttrIntVal(pads, i + nDims);
}
}
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
@ -898,16 +873,15 @@ void ONNXConvNoBiasOp::inferShapes() {
} else if (autoPad == "VALID") {
// No padding
} else {
emitError("Unexpected attribute value for auto_pad.");
emitError("Unexpected attribute value for auto_pad");
}
// Strides
if (auto strides = ONNXConvNoBiasOp::stridesAttr()) {
if (strides.getValue().size() != nDims)
emitError("strides length incompatible with spatial dimensions.");
if (ArrayAttrSize(strides) != nDims)
emitError("strides length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i) {
int64_t stride =
strides.getValue()[i].cast<IntegerAttr>().getInt();
int64_t stride = ArrayAttrIntVal(strides, i);
outSpatialDims[i] = floor(outSpatialDims[i] / stride);
}
}
@ -922,112 +896,108 @@ void ONNXConvNoBiasOp::inferShapes() {
//===----------------------------------------------------------------------===//
// MaxPoolSingleOut
// Infer shape attributes output:
// - auto_pad set to NOTSET;
// - dilations, strides: set to 1 if not defined by user;
// - pads: set to proper value, 0 if not defined by user.
void ONNXMaxPoolSingleOutOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>())
return;
auto builder = mlir::Builder(this->getContext());
// 1) get shape of input
// 1) Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape();
auto xRank = xShape.size();
// 2) analyse parameters
// get kernel sizes from kernel_shape attribute
// 2) Analyse parameters. Get kernel sizes from kernel_shape attribute.
auto kernelShape = kernel_shape();
if (!kernelShape)
emitError("kernel_shape is a mandatory attribute for which there is no default.");
auto kernelShapeArray = kernelShape.getValue();
auto kernelRank = kernelShape.size();
emitError(
"kernel_shape is a mandatory attribute for which there is no default");
auto kernelRank = ArrayAttrSize(kernelShape);
if (kernelRank > xRank)
emitError("kernel_shape spatial dimension is too large.");
emitError("kernel_shape spatial dimension is too large");
auto kernelOffset = xRank - kernelRank;
// ceil mode
// Ceil mode.
auto ceilMode = ceil_mode().getSExtValue();
// dilatation
SmallVector<int64_t, 4> actualDilations;
// Dilatation.
auto dilationsOpt = dilations();
if (dilationsOpt.hasValue()) {
auto dilationsArray = dilationsOpt.getValue().getValue(); // opt -> attr -> array
if (dilationsArray.size() != kernelRank)
emitError("dialation rank is not the same as the spatial rank.");
// fill in the actual values
if (ArrayAttrSize(dilationsOpt) != kernelRank)
emitError("dialation rank is not the same as the spatial rank");
// Test values.
for (int i = 0; i < kernelRank; ++i) {
int64_t d = (dilationsArray[i]).cast<IntegerAttr>().getInt();
if (d < 1)
emitError("dialation value must be nonzero positive.");
actualDilations.emplace_back(d);
if (ArrayAttrIntVal(dilationsOpt, i) < 1)
emitError("dialation value must be nonzero positive");
}
} else {
for(int i=0; i < kernelRank; ++i) {
actualDilations.emplace_back(1);
}
// Default dilatation is needed.
SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
// Convert to ArrayRef, then build attribute, then store attribute.
ArrayRef<int64_t> defaultRefs(defaultVals);
auto defaultAttr = builder.getI64ArrayAttr(defaultRefs);
dilationsAttr(defaultAttr);
dilationsOpt = dilations();
}
// storage order
// Storage order.
auto storageOrder = storage_order().getSExtValue();
if (storageOrder != 0)
emitError("column major storage order not supported at this time");
// strides
SmallVector<int64_t, 4> actualStrides;
// Strides.
auto stridesOpt = strides();
if (stridesOpt.hasValue()) {
auto stridesArray = stridesOpt.getValue().getValue();
if (stridesArray.size() != kernelRank)
emitError("strides rank is not the same as the spatial rank.");
// fill in the actual values
if (ArrayAttrSize(stridesOpt) != kernelRank)
emitError("strides rank is not the same as the spatial rank");
// Check values.
for (int i = 0; i < kernelRank; ++i) {
int64_t s = (stridesArray[i]).cast<IntegerAttr>().getInt();
if (s < 1)
emitError("strides value must be nonzero positive.");
actualStrides.emplace_back(s);
if (ArrayAttrIntVal(stridesOpt, i) < 1)
emitError("strides value must be nonzero positive");
}
} else {
for(int i=0; i < kernelRank; ++i) {
actualStrides.emplace_back(1);
}
SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
// Convert to ArrayRef, then build attribute, then store attribute.
ArrayRef<int64_t> defaultRefs(defaultVals);
auto defaultAttr = builder.getI64ArrayAttr(defaultRefs);
stridesAttr(defaultAttr);
stridesOpt = strides();
}
// now try to find padding, getting auto_pad attribute first
// Now try to find padding, getting auto_pad attribute first.
auto autoPad = auto_pad();
// and then investigate the various different cases
SmallVector<int64_t, 4> actualPads;
auto defaultPads = false;
// And then investigate the various different cases.
SmallVector<int64_t, 4> actualPads(2 * kernelRank, 0);
if (autoPad == "NOTSET") {
auto padsOpt = pads();
if (padsOpt.hasValue()) {
auto padsArray = padsOpt.getValue().getValue();
// pads consists of two entries for each spatial axis.
if (padsArray.size() != 2 * kernelRank)
emitError("pads rank is not twice the spatial rank.");
// fill in the actual values
// Pads consists of two entries for each spatial axis.
if (ArrayAttrSize(padsOpt) != 2 * kernelRank)
emitError("pads rank is not twice the spatial rank");
// Check values
for (int i = 0; i < 2 * kernelRank; ++i) {
int64_t p = (padsArray[i]).cast<IntegerAttr>().getInt();
int64_t p = ArrayAttrIntVal(padsOpt, i);
if (p < 0)
emitError("pads value must be nonnegative.");
actualPads.emplace_back(p);
emitError("pads value must be nonnegative");
actualPads[i] = p;
}
} else {
// pads are not defined, default to value 0
defaultPads = true;
}
} else if (autoPad == "VALID") {
defaultPads = true;
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
// init pad with zero
for(int i=0; i<2*kernelRank; ++i) {
actualPads.emplace_back(0);
}
for (int i = 0; i < kernelRank; ++i) {
auto inputSpatialShape = xShape[kernelOffset + i];
auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
auto dilations = actualDilations[i];
auto strideSpatialShape = actualStrides[i];
int64_t outputSpatialShape = ceil((1.0 * inputSpatialShape) /
(1.0 * strideSpatialShape));
auto kernelSpatialShape = ArrayAttrIntVal(kernelShape, i);
auto dilations = ArrayAttrIntVal(dilationsOpt, i);
auto strideSpatialShape = ArrayAttrIntVal(stridesOpt, i);
int64_t outputSpatialShape =
ceil((1.0 * inputSpatialShape) / (1.0 * strideSpatialShape));
auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape +
((kernelSpatialShape - 1) * dilations + 1) - inputSpatialShape;
((kernelSpatialShape - 1) * dilations + 1) -
inputSpatialShape;
actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2;
if (sumOfPad % 2 != 0) {
if (autoPad == "SAME_UPPER") {
@ -1037,27 +1007,27 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
}
}
}
} else {
emitError("auto_pad of unknown / unsupported value.");
}
// handle case where default pad values must be used
if (defaultPads) {
for(int i=0; i<2*kernelRank; ++i) {
actualPads.emplace_back(0);
} else if (autoPad != "VALID") {
emitError("auto_pad of unknown / unsupported value");
}
// Set pads values in attributes.
{
ArrayRef<int64_t> defaultRefs(actualPads);
auto defaultAttr = builder.getI64ArrayAttr(defaultRefs);
padsAttr(defaultAttr);
auto defaultAutoPadAttr = builder.getStringAttr("NOTSET");
auto_padAttr(defaultAutoPadAttr);
}
// initialize output shape
// Initialize output shape.
SmallVector<int64_t, 4> yShape(xShape.begin(), xShape.end());
// for all kernel dimensions
// Process for all kernel dimensions.
for (int i = 0; i < kernelRank; ++i) {
auto inputSpatialShape = xShape[kernelOffset + i];
auto padShape = actualPads[i] + actualPads[kernelRank + i];
auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
auto dilations = actualDilations[i];
auto strideSpatialShape = actualStrides[i];
///output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] -
// ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1)
auto kernelSpatialShape = ArrayAttrIntVal(kernelShape, i);
auto dilations = ArrayAttrIntVal(dilationsOpt, i);
auto strideSpatialShape = ArrayAttrIntVal(stridesOpt, i);
double numerator = inputSpatialShape + padShape -
((kernelSpatialShape - 1) * dilations + 1);
double denominator = strideSpatialShape;
@ -1069,7 +1039,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
}
yShape[kernelOffset + i] = res;
}
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
auto arrayTy = X().getType().cast<RankedTensorType>();
getResult().setType(RankedTensorType::get(yShape, arrayTy.getElementType()));
}
@ -1152,10 +1122,10 @@ void ONNXPadConstantValuePadOp::inferShapes(){
// Unsqueeze
void ONNXUnsqueezeOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>())
if (!data().getType().isa<RankedTensorType>())
return;
auto operandTy = getOperand().getType().cast<RankedTensorType>();
auto operandTy = data().getType().cast<RankedTensorType>();
int inRank = operandTy.getRank();
ArrayAttr axisAttrs = axesAttr();
@ -1171,10 +1141,10 @@ void ONNXUnsqueezeOp::inferShapes() {
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.emplace_back(axis);
else
emitError("Duplicated axes.");
emitError("Duplicated axes");
}
} else {
emitError("Axes attribute is required.");
emitError("Axes attribute is required");
}
SmallVector<int64_t, 4> dims;

View File

@ -1,7 +1,8 @@
//********************************************************
// Warning: Do not modify this file directly
// This file is automatically generated via script
// Details can be found in doc/readonnxdefs.md
// This file is generated on UTC-02/24/2020, 06:44:13.
// Do not modify this file directly.
// This file is automatically generated via script.
// Details can be found in doc/readonnxdefs.md .
//********************************************************
def ONNXAbsOp:ONNX_Op<"Abs",
@ -213,10 +214,10 @@ def ONNXBatchNormalizationOp:ONNX_Op<"BatchNormalization",
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
DefaultValuedAttr<F32Attr, "0.9">:$momentum);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$out_mean,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$out_var,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$saved_mean,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$saved_var);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$out_mean,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$out_var,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$saved_mean,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$saved_var);
}
def ONNXBitShiftOp:ONNX_Op<"BitShift",
@ -224,12 +225,12 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
let summary = "ONNX BitShift operation";
let description = [{
"Bitwise shift operator performs element-wise operation. For each input element, if the"
" attribute "direction" is "RIGHT", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute "direction""
" is "LEFT", bits of binary representation moves toward the left side, which results the"
" attribute \"direction\" is \"RIGHT\", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute \"direction\""
" is \"LEFT\", bits of binary representation moves toward the left side, which results the"
" increase of its actual value. The input X is the tensor to be shifted and another input"
" Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with"
" Y specifies the amounts of shifting. For example, if \"direction\" is \"Right\", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If \"direction\" is \"LEFT\" with"
" X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]."
" "
" Because this operator supports Numpy-style broadcasting, X's and Y's shapes are"
@ -251,15 +252,15 @@ def ONNXCastOp:ONNX_Op<"Cast",
"the converted type. The 'to' argument must be one of the data types specified"
"in the 'DataType' enum field in the TensorProto message."
""
"Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations"
"(e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may"
"Casting from string tensor in plain (e.g., \"3.14\" and \"1000\") and scientific numeric representations"
"(e.g., \"1e-5\" and \"1E8\") to float types is supported. For example, converting string \"100.5\" to an integer may"
"result 100. There are some string literals reserved for special floating-point values;"
""+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as "314.15926") would be used. "
"Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior."
"\"+INF\" (and \"INF\"), \"-INF\", and \"NaN\" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match \"+INF\" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to \"INF\" and \"NaN\". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as \"314.15926\") would be used. "
"Converting non-numerical-literal string such as \"Hello World!\" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as \"2.718\", to INT is an undefined behavior."
""
"Conversion from a numerical type to any numerical type is always allowed."
"User must be aware of precision loss and value change caused by range difference between two types."
@ -292,8 +293,8 @@ def ONNXClipOp:ONNX_Op<"Clip",
"numeric_limits::lowest() and numeric_limits::max(), respectively."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$min,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$max);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$min,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$max);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
}
@ -370,7 +371,7 @@ def ONNXConvOp:ONNX_Op<"Conv",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
OptionalAttr<I64ArrayAttr>:$dilations,
DefaultValuedAttr<I64Attr, "1">:$group,
@ -389,8 +390,8 @@ def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$w,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$x_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$w_zero_point,
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
OptionalAttr<I64ArrayAttr>:$dilations,
DefaultValuedAttr<I64Attr, "1">:$group,
@ -421,7 +422,7 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
OptionalAttr<I64ArrayAttr>:$dilations,
DefaultValuedAttr<I64Attr, "1">:$group,
@ -534,7 +535,7 @@ def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_scale,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_zero_point);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$x_zero_point);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$y);
}
@ -579,7 +580,7 @@ def ONNXDropoutOp:ONNX_Op<"Dropout",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
DefaultValuedAttr<F32Attr, "0.5">:$ratio);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$mask);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$mask);
}
def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear",
@ -817,9 +818,9 @@ def ONNXGRUOp:ONNX_Op<"GRU",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$R,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_h,
OptionalAttr<F32ArrayAttr>:$activation_alpha,
OptionalAttr<F32ArrayAttr>:$activation_beta,
OptionalAttr<StrArrayAttr>:$activations,
@ -827,8 +828,8 @@ def ONNXGRUOp:ONNX_Op<"GRU",
DefaultValuedAttr<StrAttr, "forward">:$direction,
OptionalAttr<I64Attr>:$hidden_size,
DefaultValuedAttr<I64Attr, "0">:$linear_before_reset);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y_h);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_h);
}
def ONNXGatherOp:ONNX_Op<"Gather",
@ -1042,6 +1043,7 @@ def ONNXGatherNDOp:ONNX_Op<"GatherND",
def ONNXGemmOp:ONNX_Op<"Gemm",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1;
let summary = "ONNX Gemm operation";
let description = [{
"General Matrix multiplication:"
@ -1060,7 +1062,7 @@ def ONNXGemmOp:ONNX_Op<"Gemm",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$C,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$C,
DefaultValuedAttr<F32Attr, "1.0">:$alpha,
DefaultValuedAttr<F32Attr, "1.0">:$beta,
DefaultValuedAttr<I64Attr, "0">:$transA,
@ -1332,11 +1334,11 @@ def ONNXLSTMOp:ONNX_Op<"LSTM",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$R,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_c,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$P,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_h,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_c,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$P,
OptionalAttr<F32ArrayAttr>:$activation_alpha,
OptionalAttr<F32ArrayAttr>:$activation_beta,
OptionalAttr<StrArrayAttr>:$activations,
@ -1344,9 +1346,9 @@ def ONNXLSTMOp:ONNX_Op<"LSTM",
DefaultValuedAttr<StrAttr, "forward">:$direction,
OptionalAttr<I64Attr>:$hidden_size,
DefaultValuedAttr<I64Attr, "0">:$input_forget);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y_h,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y_c);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_h,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_c);
}
def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu",
@ -1430,24 +1432,24 @@ def ONNXLoopOp:ONNX_Op<"Loop",
""
" Operator inputs defined as (max_trip_count, condition_var)."
""
" input ("", ""):"
" input (\"\", \"\"):"
" for (int i=0; ; ++i) {"
" cond = ... // Note this value is ignored, but is required in the body"
" }"
""
" input ("", cond) // Note this is analogous to a while loop"
" input (\"\", cond) // Note this is analogous to a while loop"
" bool cond = ...;"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
" input ("", 1) // Note this is analogous to a do-while loop"
" input (\"\", 1) // Note this is analogous to a do-while loop"
" bool cond = true"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
" input (trip_count, "") // Note this is analogous to a for loop"
" input (trip_count, \"\") // Note this is analogous to a for loop"
" int trip_count = ..."
" for (int i=0; i < trip_count; ++i) {"
" cond = ...; // ignored"
@ -1473,15 +1475,15 @@ def ONNXLoopOp:ONNX_Op<"Loop",
" }"
""
" graph body-net ("
" %i[INT32, scalar] // iteration number"
" %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used"
" %b_in[INT32, scalar] // incoming value of loop-carried-dependency b"
" %i[INT32, scalar]"
" %keepgoing[BOOL, scalar]"
" %b[INT32, scalar]"
" ) {"
" %my_local = Add(%a, %b_in)"
" %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b"
" %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition"
" %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated"
" return %keepgoing_out, %b_out, %user_defined_val"
" %my_local = Add(%a, %b)"
" %b_out = Sub(%a, %b)"
" %keepgoing_out = Greater(%my_local, %b_out)"
" %user_defined_vals = Add(%b, %b)"
" return %keepgoing_out, %b_out, %user_defined_vals"
" }"
""
"*Sample equivalent C code*"
@ -1496,51 +1498,31 @@ def ONNXLoopOp:ONNX_Op<"Loop",
" const int max_trip_count = 10; // Analogous to input M"
" int user_defined_vals[]; // Imagine this is resizable"
" /* End implicitly-defined code */"
" /* initialize loop-carried variables and scan-output variables */"
" bool keepgoing_out = keepgoing"
" int b_out = b"
""
" for (int i=0; i < max_trip_count && keepgoing_out; ++i) {"
" /* Implicitly-defined code: bind actual parameter values"
" to formal parameter variables of loop-body */"
" bool keepgoing_in = keepgoing_out; "
" bool b_in = b_out;"
""
" for (int i=0; i < max_trip_count && keepgoing; ++i) {"
" /* User-defined code (loop body) */"
" int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine"
" b_out = a - b_in;"
" keepgoing_out = my_local > b_out; "
" user_defined_val = b_in + b_in; // b_in and b_out are different variables"
" int my_local = a + b; // Reading values in the enclosing scope is fine"
" b = a - b; // writes fine if we specify b as a loop-carried dependency"
" keepgoing = my_local > b; // keepgoing is a loop-carried dependency"
" user_defined_vals[i] = b + b;"
" /* End user-defined code */"
""
" /* Implicitly defined-code */"
" user_defined_vals[i] = user_defined_val // accumulate scan-output values"
" }"
" // int t = my_local; // Can't do this. my_local is not accessible here."
" // my_local = 123; // Can't do this. my_local was defined in the the body"
""
" // The values below are bound to the output variables of the loop and therefore accessible"
" // b_out; user_defined_vals; keepgoing_out;"
" // These below values are live-out from the loop and therefore accessible"
" b_out; user_defined_vals; keepgoing_out;"
" }"
""
"There are several things of note in this code snippet:"
""
"1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can"
"1) Values from the enclosing scope (i.e. variable a here) are in scope and can"
" be referenced in the inputs of the loop."
"2) Any values computed in the loop body that needs to be used in a subsequent"
" iteration or after the loop are modelled using a pair of variables in the loop-body,"
" consisting of an input variable (eg., b_in) and an output variable (eg., b_out)."
" These are referred to as loop-carried dependences. The loop operation node"
" supplies the input value of the input variable for the first iteration, and"
" returns the output value of the output variable produced by the final"
" iteration."
"3) Scan_output variables are used to implicitly concatenate values computed across"
" all the iterations. In the above example, the value of user_defined_val computed"
" over all iterations are concatenated and returned as the value of user_defined_vals"
" after the loop."
"4) Values created in the body cannot be accessed in the enclosing scope,"
" except using the mechanism described above."
"2) Any variables which you wish to make available in the enclosing scope (i.e."
" the variables b and keepgoing) must be declared as either loop-carried"
" dependencies (both at the op inputs and output and at the body net input and"
" output) or scan_outputs."
"3) Values created in the body cannot be accessed in the enclosing scope."
""
"Note that the semantics of this op support "diagonal" or "wavefront" execution."
"Note that the semantics of this op support \"diagonal\" or \"wavefront\" execution."
"(See Step 3 here for an example:"
"https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)."
"Frontends should emit multi-layer RNNs as a series of While operators (with"
@ -1548,8 +1530,8 @@ def ONNXLoopOp:ONNX_Op<"Loop",
"the scan_outputs from the previous layer, possibly going through several"
"point-wise operators (e.g. dropout, residual connections, linear layer)."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$M,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$cond,
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$M,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$cond,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$v_initial,
AnyAttr:$body);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$v_final_and_scan_outputs);
@ -1606,8 +1588,8 @@ def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$a_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$b_zero_point);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$a_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$b_zero_point);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
}
@ -1666,7 +1648,7 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool",
DefaultValuedAttr<I64Attr, "0">:$storage_order,
OptionalAttr<I64ArrayAttr>:$strides);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Indices);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Indices);
}
def ONNXMaxRoiPoolOp:ONNX_Op<"MaxRoiPool",
@ -1709,7 +1691,7 @@ def ONNXMaxUnpoolOp:ONNX_Op<"MaxUnpool",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$I,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_shape,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$output_shape,
I64ArrayAttr:$kernel_shape,
OptionalAttr<I64ArrayAttr>:$pads,
OptionalAttr<I64ArrayAttr>:$strides);
@ -1841,9 +1823,9 @@ def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$boxes,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$scores,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$max_output_boxes_per_class,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$iou_threshold,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$score_threshold,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$max_output_boxes_per_class,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$iou_threshold,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$score_threshold,
DefaultValuedAttr<I64Attr, "0">:$center_point_box);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$selected_indices);
}
@ -2018,7 +2000,7 @@ def ONNXPadOp:ONNX_Op<"Pad",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$pads,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$constant_value,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$constant_value,
DefaultValuedAttr<StrAttr, "constant">:$mode);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
}
@ -2055,7 +2037,7 @@ def ONNXQLinearConvOp:ONNX_Op<"QLinearConv",
AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
OptionalAttr<I64ArrayAttr>:$dilations,
DefaultValuedAttr<I64Attr, "1">:$group,
@ -2099,7 +2081,7 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$y_zero_point);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$y);
}
@ -2172,17 +2154,17 @@ def ONNXRNNOp:ONNX_Op<"RNN",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$R,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_h,
OptionalAttr<F32ArrayAttr>:$activation_alpha,
OptionalAttr<F32ArrayAttr>:$activation_beta,
DefaultValuedAttr<StrArrayAttr, "{\"Tanh\", \"Tanh\"}">:$activations,
OptionalAttr<F32Attr>:$clip,
DefaultValuedAttr<StrAttr, "forward">:$direction,
OptionalAttr<I64Attr>:$hidden_size);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y_h);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_h);
}
def ONNXRandomNormalOp:ONNX_Op<"RandomNormal",
@ -2545,12 +2527,12 @@ def ONNXResizeOp:ONNX_Op<"Resize",
let description = [{
"Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor."
"Each dimension value of the output tensor is:"
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified."
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$roi,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$scales,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$sizes,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sizes,
DefaultValuedAttr<StrAttr, "half_pixel">:$coordinate_transformation_mode,
DefaultValuedAttr<F32Attr, "-0.75">:$cubic_coeff_a,
DefaultValuedAttr<I64Attr, "0">:$exclude_outside,
@ -3044,7 +3026,7 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase",
"'position' is optional, by default it erases the last tensor from 'input_sequence'."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$position);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$position);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
}
@ -3060,7 +3042,7 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$position);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$position);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
}
@ -3194,8 +3176,8 @@ def ONNXSliceOp:ONNX_Op<"Slice",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$starts,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$ends,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$axes,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$steps);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$axes,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$steps);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
}
@ -3269,7 +3251,7 @@ def ONNXSplitOp:ONNX_Op<"Split",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input,
DefaultValuedAttr<I64Attr, "0">:$axis,
OptionalAttr<I64ArrayAttr>:$split);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$outputs);
let results = (outs Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$outputs);
}
def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
@ -3288,7 +3270,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
"dimension size of input tensor on 'axis'."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$split,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$split,
DefaultValuedAttr<I64Attr, "0">:$axis,
DefaultValuedAttr<I64Attr, "1">:$keepdims);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
@ -3327,9 +3309,9 @@ def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer",
"StringNormalization performs string operations for basic cleaning."
"This operator has only one input (denoted by X) and only one output"
"(denoted by Y). This operator first examines the elements in the X,"
"and removes elements specified in "stopwords" attribute."
"and removes elements specified in \"stopwords\" attribute."
"After removing stop words, the intermediate result can be further lowercased,"
"uppercased, or just returned depending the "case_change_action" attribute."
"uppercased, or just returned depending the \"case_change_action\" attribute."
"This operator only accepts [C]- and [1, C]-tensor."
"If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]"
"if input shape is [C] and shape [1, 1] if input shape is [1, C]."
@ -3412,8 +3394,8 @@ def ONNXTfIdfVectorizerOp:ONNX_Op<"TfIdfVectorizer",
"respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output."
"Note that we may consider all skips up to S when generating the n-grams."
""
"The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF","
"The examples used above are true if mode is \"TF\". If mode is \"IDF\", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is \"TFIDF\","
"this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute."
""
"Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor."
@ -3470,9 +3452,9 @@ def ONNXTopKOp:ONNX_Op<"TopK",
" contains the indices of the top k elements (original indices from the input"
" tensor)."
""
"If "largest" is 1 (the default value) then the k largest elements are returned."
"If "sorted" is 1 (the default value) then the resulting k elements will be sorted."
"If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined."
"If \"largest\" is 1 (the default value) then the k largest elements are returned."
"If \"sorted\" is 1 (the default value) then the resulting k elements will be sorted."
"If \"sorted\" is 0, order of returned 'Values' and 'Indices' are undefined."
""
"Given two equivalent values, this operator uses the indices along the axis as"
" a tiebreaker. That is, the element with the lower index will appear first."
@ -3509,7 +3491,7 @@ def ONNXUniqueOp:ONNX_Op<"Unique",
"This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. "
"The first output tensor 'Y' contains all unique values or subtensors of the input. "
"The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. \". "
"The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. "
""
"Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. "
@ -3583,9 +3565,9 @@ def ONNXUniqueOp:ONNX_Op<"Unique",
OptionalAttr<I64Attr>:$axis,
DefaultValuedAttr<I64Attr, "1">:$sorted);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$inverse_indices,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$counts);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$indices,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$inverse_indices,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$counts);
}
def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
@ -3652,3 +3634,4 @@ def ONNXXorOp:ONNX_Op<"Xor",
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C);
}

View File

@ -127,6 +127,10 @@ int main(int argc, char *argv[]) {
if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass());
// An additional pass of canonicalization is helpful because lowering
// from ONNX dialect to Standard dialect exposes additional canonicalization
// oppertunities.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createLowerKrnlPass());
}

View File

@ -28,6 +28,11 @@ void ONNXAddOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<MulAddToGemmOptPattern>(context);
}
void ONNXGemmOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<FuseGemmFollowedByAddition>(context);
}
/// on the ONNXIdentityOp.
void ONNXIdentityOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {

View File

@ -26,6 +26,7 @@ include "dialect/onnx/onnx.td"
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite
@ -41,6 +42,11 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
(ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
[(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>;
// onnx.add(onnx.Gemm(%X, %Y, None), %Z) = onnx.Gemm(%X, %Y, %Z)
def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none, $alpha, $beta, $transA, $transB), $bias),
(ONNXGemmOp $m1, $m2, $bias, $alpha, $beta, $transA, $transB),
[(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasNoneType $none)]>;
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
(replaceWithValue $arg)>;

View File

@ -118,7 +118,6 @@ public:
op->getName().getStringRef() != "onnx.Identity" &&
op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.GemmNoBias" &&
op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose" &&
op->getName().getStringRef() != "onnx.ReduceMax" &&

View File

@ -101,3 +101,14 @@ func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>
// CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
// CHECK-NEXT: return %1 : tensor<*xf32>
}
//CHECK-LABEL: @test_gemm_add_fusion(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> {
func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32>
%1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32>
// return [[GEMM]] : tensor<*xf32>
}

View File

@ -806,35 +806,6 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
// CHECK: }
}
func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
%0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_gemm_no_bias
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
// CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
// CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
// CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
// CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
// CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
// CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
// CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
// CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
// CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: }
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
// CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<10x10xf32>
// CHECK: }
}
func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()

View File

@ -6,7 +6,7 @@ func @test_default_maxpoolsingleout(%arg0 : tensor<5x5x32x32xf32>) -> tensor<*xf
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_default_maxpoolsingleout
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "VALID", ceil_mode = 0 : i64, kernel_shape = [3, 3], pads = [1, 1, 1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32>
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32>
// CHECK: return [[RES]] : tensor<5x5x30x30xf32>
@ -16,7 +16,7 @@ func @test_default_maxpoolsingleout_defpad(%arg0 : tensor<5x5x32x32xf32>) -> ten
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_default_maxpoolsingleout_defpad
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [3, 3]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32>
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32>
// CHECK: return [[RES]] : tensor<5x5x30x30xf32>
@ -26,7 +26,7 @@ func @test_default_maxpoolsingleout_pad(%arg0 : tensor<5x5x32x32xf32>) -> tensor
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_default_maxpoolsingleout_pad
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [3, 3], pads = [1, 1, 1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32>
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32>
// CHECK: return [[RES]] : tensor<5x5x32x32xf32>
@ -36,7 +36,7 @@ func @test_default_maxpoolsingleout_pad_nonunif(%arg0 : tensor<5x5x32x32xf32>) -
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_default_maxpoolsingleout_pad_nonunif
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [2, 1, 1, 0]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x31x31xf32>
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [5, 3], pads = [2, 1, 1, 0], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x31x31xf32>
// CHECK: return [[RES]] : tensor<5x5x31x31xf32>
@ -46,7 +46,7 @@ func @test_default_maxpoolsingleout_strides(%arg0 : tensor<5x5x32x32xf32>) -> te
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_default_maxpoolsingleout_strides
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x16x16xf32>
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x16x16xf32>
// CHECK: return [[RES]] : tensor<5x5x16x16xf32>
@ -56,7 +56,7 @@ func @test_default_maxpoolsingleout_strides_nonunifpad(%arg0 : tensor<5x5x30x32x
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_default_maxpoolsingleout_strides_nonunifpad
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32>
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32>
// CHECK: return [[RES]] : tensor<5x5x15x16xf32>
@ -66,7 +66,7 @@ func @test_default_maxpoolsingleout_strides_nonunifpad_ceil(%arg0 : tensor<5x5x3
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_default_maxpoolsingleout_strides_nonunifpad_ceil
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 1 : i64, kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x16x16xf32>
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 1 : i64, dilations = [1, 1], kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x16x16xf32>
// CHECK: return [[RES]] : tensor<5x5x16x16xf32>
@ -76,7 +76,7 @@ func @test_default_maxpoolsingleout_strides_dilatation(%arg0 : tensor<5x5x8x8xf3
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_default_maxpoolsingleout_strides_dilatation
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [2, 2], kernel_shape = [2, 2], strides = [3, 3]} : (tensor<5x5x8x8xf32>) -> tensor<5x5x2x2xf32>
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [2, 2], kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [3, 3]} : (tensor<5x5x8x8xf32>) -> tensor<5x5x2x2xf32>
// CHECK: return [[RES]] : tensor<5x5x2x2xf32>
/// Test the default behavior of Max Pool with dilatation
@ -85,7 +85,7 @@ func @test_default_maxpoolsingleout_upper(%arg0 : tensor<5x5x16x13xf32>) -> tens
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_default_maxpoolsingleout_upper
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "SAME_UPPER", ceil_mode = 0 : i64, kernel_shape = [4, 4], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32>
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [4, 4], pads = [0, 1, 0, 2], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32>
// CHECK: return [[RES]] : tensor<5x5x4x4xf32>
@ -95,6 +95,6 @@ func @test_default_maxpoolsingleout_lower(%arg0 : tensor<5x5x16x13xf32>) -> tens
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_default_maxpoolsingleout_lower
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "SAME_LOWER", ceil_mode = 0 : i64, kernel_shape = [4, 4], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32>
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [4, 4], pads = [0, 2, 0, 1], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32>
// CHECK: return [[RES]] : tensor<5x5x4x4xf32>

2
third_party/onnx vendored

@ -1 +1 @@
Subproject commit 1439eab5542c625bb3da49860f0cd68c3eafdc18
Subproject commit 553df22c67bee5f0fe6599cff60f1afc6748c635