Merge remote-tracking branch 'upstream/master' into shapeinference-pad

commit 4079ee1f26

@@ -38,7 +38,7 @@ jobs:
       - run:
           name: Run End-To-End Tests
           command: |
-            sudo pip install -q onnx
+            sudo pip install -q -e ./ONNF/third_party/onnx
             cd ONNF/build
             cmake --build . --target run-onnx-backend-test
       - run:
@@ -1,2 +1,3 @@
 BasedOnStyle: LLVM
 AlwaysBreakTemplateDeclarations: Yes
+AlignAfterOpenBracket: DontAlign
@@ -30,3 +30,145 @@
 *.exe
 *.out
 *.app
+
+# Filesystem
+.DS_Store
+
+# The following .gitignore content is taken from
+# https://github.com/github/gitignore/blob/master/Python.gitignore
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
@@ -327,10 +327,10 @@ ONNX BatchNormalization operation
 #### Results:
 
 1. `Y`: memref of any type values or tensor of any type values
-1. `out_mean`: memref of any type values or tensor of any type values
-1. `out_var`: memref of any type values or tensor of any type values
-1. `saved_mean`: memref of any type values or tensor of any type values
-1. `saved_var`: memref of any type values or tensor of any type values
+1. `out_mean`: memref of any type values or tensor of any type values or none type
+1. `out_var`: memref of any type values or tensor of any type values or none type
+1. `saved_mean`: memref of any type values or tensor of any type values or none type
+1. `saved_var`: memref of any type values or tensor of any type values or none type
 
 ### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
 ONNX BatchNormalization operation in test mode
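Most of the documentation hunks below make the same change: optional ONNX operands and results are now typed `or none type`. As illustration only (this sketch is not part of the commit, and it assumes the MLIR API of this period, i.e. `mlir::NoneType` and `Type::isa`), code consuming values under this convention can detect an omitted optional operand by its type:

```cpp
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"

// An omitted optional operand is wired to a NoneType constant by the
// importer, so a lowering can branch on the operand's type.
static bool hasOptionalOperand(mlir::Value v) {
  return !v.getType().isa<mlir::NoneType>();
}
```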
@@ -375,12 +375,12 @@ ONNX BitShift operation
 
 
 "Bitwise shift operator performs element-wise operation. For each input element, if the"
-" attribute "direction" is "RIGHT", this operator moves its binary representation toward"
-" the right side so that the input value is effectively decreased. If the attribute "direction""
-" is "LEFT", bits of binary representation moves toward the left side, which results the"
+" attribute \"direction\" is \"RIGHT\", this operator moves its binary representation toward"
+" the right side so that the input value is effectively decreased. If the attribute \"direction\""
+" is \"LEFT\", bits of binary representation moves toward the left side, which results the"
 " increase of its actual value. The input X is the tensor to be shifted and another input"
-" Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4],"
-" and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with"
+" Y specifies the amounts of shifting. For example, if \"direction\" is \"Right\", X is [1, 4],"
+" and S is [1, 1], the corresponding output Z would be [0, 2]. If \"direction\" is \"LEFT\" with"
 " X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]."
 " "
 " Because this operator supports Numpy-style broadcasting, X's and Y's shapes are"
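The two shift examples quoted in this hunk are easy to sanity-check; a minimal standalone program (illustrative, not part of the diff) reproduces both:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // direction = "RIGHT": X = [1, 4], S = [1, 1] -> Z = [0, 2]
  assert((uint32_t(1) >> 1) == 0 && (uint32_t(4) >> 1) == 2);
  // direction = "LEFT": X = [1, 2], S = [1, 2] -> Y = [2, 8]
  assert((uint32_t(1) << 1) == 2 && (uint32_t(2) << 2) == 8);
  return 0;
}
```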
@@ -413,15 +413,15 @@ ONNX Cast operation
 "the converted type. The 'to' argument must be one of the data types specified"
 "in the 'DataType' enum field in the TensorProto message."
 ""
-"Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations"
-"(e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may"
+"Casting from string tensor in plain (e.g., \"3.14\" and \"1000\") and scientific numeric representations"
+"(e.g., \"1e-5\" and \"1E8\") to float types is supported. For example, converting string \"100.5\" to an integer may"
 "result 100. There are some string literals reserved for special floating-point values;"
-""+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively."
-"Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly,"
-"this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors"
-"to string tensors, plain floating-point representation (such as "314.15926") would be used. "
-"Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases "
-"of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior."
+"\"+INF\" (and \"INF\"), \"-INF\", and \"NaN\" are positive infinity, negative infinity, and not-a-number, respectively."
+"Any string which can exactly match \"+INF\" in a case-insensitive way would be mapped to positive infinite. Similarly,"
+"this case-insensitive rule is applied to \"INF\" and \"NaN\". When casting from numeric tensors"
+"to string tensors, plain floating-point representation (such as \"314.15926\") would be used. "
+"Converting non-numerical-literal string such as \"Hello World!\" is an undefined behavior. Cases "
+"of converting string representing floating-point arithmetic value, such as \"2.718\", to INT is an undefined behavior."
 ""
 "Conversion from a numerical type to any numerical type is always allowed."
 "User must be aware of precision loss and value change caused by range difference between two types."
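The parsing rules quoted above (plain and scientific forms parse, fractional strings truncate when cast to an integer, and the INF/NaN literals map to the special values) happen to match C's strtof-style parsing; a rough sketch, not the operator's actual implementation:

```cpp
#include <cassert>
#include <cmath>
#include <string>

int main() {
  assert(std::stof("3.14") > 3.13f);                    // plain representation
  assert(std::stof("1e-5") > 0.0f);                     // scientific representation
  assert(static_cast<int>(std::stof("100.5")) == 100);  // may truncate to 100
  assert(std::isinf(std::stof("INF")));                 // reserved literal
  assert(std::isnan(std::stof("NaN")));                 // reserved literal
  return 0;
}
```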
@@ -476,8 +476,8 @@ ONNX Clip operation
 #### Operands:
 
 1. `input`: memref of any type values or tensor of any type values
-1. `min`: memref of any type values or tensor of any type values
-1. `max`: memref of any type values or tensor of any type values
+1. `min`: memref of any type values or tensor of any type values or none type
+1. `max`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -618,8 +618,8 @@ ONNX ConvInteger operation
 
 1. `x`: memref of any type values or tensor of any type values
 1. `w`: memref of any type values or tensor of any type values
-1. `x_zero_point`: memref of any type values or tensor of any type values
-1. `w_zero_point`: memref of any type values or tensor of any type values
+1. `x_zero_point`: memref of any type values or tensor of any type values or none type
+1. `w_zero_point`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -678,7 +678,7 @@ ONNX Conv operation
 
 1. `X`: memref of any type values or tensor of any type values
 1. `W`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -720,7 +720,7 @@ ONNX ConvTranspose operation
 
 1. `X`: memref of any type values or tensor of any type values
 1. `W`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -884,7 +884,7 @@ ONNX DequantizeLinear operation
 
 1. `x`: memref of any type values or tensor of any type values
 1. `x_scale`: memref of any type values or tensor of any type values
-1. `x_zero_point`: memref of any type values or tensor of any type values
+1. `x_zero_point`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -964,7 +964,7 @@ ONNX Dropout operation
 #### Results:
 
 1. `output`: memref of any type values or tensor of any type values
-1. `mask`: memref of any type values or tensor of any type values
+1. `mask`: memref of any type values or tensor of any type values or none type
 
 ### onnx.DynamicQuantizeLinear (ONNXDynamicQuantizeLinearOp)
 ONNX DynamicQuantizeLinear operation
@@ -1297,9 +1297,9 @@ ONNX GRU operation
 1. `X`: memref of any type values or tensor of any type values
 1. `W`: memref of any type values or tensor of any type values
 1. `R`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-1. `sequence_lens`: memref of any type values or tensor of any type values
-1. `initial_h`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
+1. `sequence_lens`: memref of any type values or tensor of any type values or none type
+1. `initial_h`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -1315,8 +1315,8 @@ ONNX GRU operation
 
 #### Results:
 
-1. `Y`: memref of any type values or tensor of any type values
-1. `Y_h`: memref of any type values or tensor of any type values
+1. `Y`: memref of any type values or tensor of any type values or none type
+1. `Y_h`: memref of any type values or tensor of any type values or none type
 
 ### onnx.GatherElements (ONNXGatherElementsOp)
 ONNX GatherElements operation
@@ -1558,33 +1558,6 @@ ONNX Gather operation
 
 1. `output`: memref of any type values or tensor of any type values
 
-### onnx.GemmNoBias (ONNXGemmNoBiasOp)
-ONNX general matrix multiply operation without bias.
-
-#### Description:
-
-
-The "onnx.Gemm" generic matrix multiplication without bias.
-
-
-#### Operands:
-
-1. `A`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-
-#### Attributes:
-
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `alpha` | `FloatAttr` | 32-bit float attribute attribute |
-| `beta` | `FloatAttr` | 32-bit float attribute attribute |
-| `transA` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `transB` | `IntegerAttr` | 64-bit integer attribute attribute |
-
-#### Results:
-
-1. `o_Y`: memref of any type values or tensor of any type values
-
 ### onnx.Gemm (ONNXGemmOp)
 ONNX Gemm operation
 
@@ -1609,7 +1582,7 @@ ONNX Gemm operation
 
 1. `A`: memref of any type values or tensor of any type values
 1. `B`: memref of any type values or tensor of any type values
-1. `C`: memref of any type values or tensor of any type values
+1. `C`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -2013,11 +1986,11 @@ ONNX LSTM operation
 1. `X`: memref of any type values or tensor of any type values
 1. `W`: memref of any type values or tensor of any type values
 1. `R`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-1. `sequence_lens`: memref of any type values or tensor of any type values
-1. `initial_h`: memref of any type values or tensor of any type values
-1. `initial_c`: memref of any type values or tensor of any type values
-1. `P`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
+1. `sequence_lens`: memref of any type values or tensor of any type values or none type
+1. `initial_h`: memref of any type values or tensor of any type values or none type
+1. `initial_c`: memref of any type values or tensor of any type values or none type
+1. `P`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -2033,9 +2006,9 @@ ONNX LSTM operation
 
 #### Results:
 
-1. `Y`: memref of any type values or tensor of any type values
-1. `Y_h`: memref of any type values or tensor of any type values
-1. `Y_c`: memref of any type values or tensor of any type values
+1. `Y`: memref of any type values or tensor of any type values or none type
+1. `Y_h`: memref of any type values or tensor of any type values or none type
+1. `Y_c`: memref of any type values or tensor of any type values or none type
 
 ### onnx.LeakyRelu (ONNXLeakyReluOp)
 ONNX LeakyRelu operation
@@ -2160,24 +2133,24 @@ ONNX Loop operation
 ""
 " Operator inputs defined as (max_trip_count, condition_var)."
 ""
-" input ("", ""):"
+" input (\"\", \"\"):"
 " for (int i=0; ; ++i) {"
 " cond = ... // Note this value is ignored, but is required in the body"
 " }"
 ""
-" input ("", cond) // Note this is analogous to a while loop"
+" input (\"\", cond) // Note this is analogous to a while loop"
 " bool cond = ...;"
 " for (int i=0; cond; ++i) {"
 " cond = ...;"
 " }"
 ""
-" input ("", 1) // Note this is analogous to a do-while loop"
+" input (\"\", 1) // Note this is analogous to a do-while loop"
 " bool cond = true"
 " for (int i=0; cond; ++i) {"
 " cond = ...;"
 " }"
 ""
-" input (trip_count, "") // Note this is analogous to a for loop"
+" input (trip_count, \"\") // Note this is analogous to a for loop"
 " int trip_count = ..."
 " for (int i=0; i < trip_count; ++i) {"
 " cond = ...; // ignored"
@@ -2203,15 +2176,15 @@ ONNX Loop operation
 " }"
 ""
 " graph body-net ("
-" %i[INT32, scalar] // iteration number"
-" %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used"
-" %b_in[INT32, scalar] // incoming value of loop-carried-dependency b"
+" %i[INT32, scalar]"
+" %keepgoing[BOOL, scalar]"
+" %b[INT32, scalar]"
 " ) {"
-" %my_local = Add(%a, %b_in)"
-" %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b"
-" %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition"
-" %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated"
-" return %keepgoing_out, %b_out, %user_defined_val"
+" %my_local = Add(%a, %b)"
+" %b_out = Sub(%a, %b)"
+" %keepgoing_out = Greater(%my_local, %b_out)"
+" %user_defined_vals = Add(%b, %b)"
+" return %keepgoing_out, %b_out, %user_defined_vals"
 " }"
 ""
 "*Sample equivalent C code*"
@@ -2226,51 +2199,31 @@ ONNX Loop operation
 " const int max_trip_count = 10; // Analogous to input M"
 " int user_defined_vals[]; // Imagine this is resizable"
 " /* End implicitly-defined code */"
-" /* initialize loop-carried variables and scan-output variables */"
-" bool keepgoing_out = keepgoing"
-" int b_out = b"
-""
-" for (int i=0; i < max_trip_count && keepgoing_out; ++i) {"
-" /* Implicitly-defined code: bind actual parameter values"
-" to formal parameter variables of loop-body */"
-" bool keepgoing_in = keepgoing_out; "
-" bool b_in = b_out;"
-""
+" for (int i=0; i < max_trip_count && keepgoing; ++i) {"
 " /* User-defined code (loop body) */"
-" int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine"
-" b_out = a - b_in;"
-" keepgoing_out = my_local > b_out; "
-" user_defined_val = b_in + b_in; // b_in and b_out are different variables"
+" int my_local = a + b; // Reading values in the enclosing scope is fine"
+" b = a - b; // writes fine if we specify b as a loop-carried dependency"
+" keepgoing = my_local > b; // keepgoing is a loop-carried dependency"
+" user_defined_vals[i] = b + b;"
 " /* End user-defined code */"
-""
-" /* Implicitly defined-code */"
-" user_defined_vals[i] = user_defined_val // accumulate scan-output values"
 " }"
-" // int t = my_local; // Can't do this. my_local is not accessible here."
+" // my_local = 123; // Can't do this. my_local was defined in the the body"
 ""
-" // The values below are bound to the output variables of the loop and therefore accessible"
-" // b_out; user_defined_vals; keepgoing_out;"
+" // These below values are live-out from the loop and therefore accessible"
+" b_out; user_defined_vals; keepgoing_out;"
 " }"
 ""
 "There are several things of note in this code snippet:"
 ""
-"1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can"
+"1) Values from the enclosing scope (i.e. variable a here) are in scope and can"
 " be referenced in the inputs of the loop."
-"2) Any values computed in the loop body that needs to be used in a subsequent"
-" iteration or after the loop are modelled using a pair of variables in the loop-body,"
-" consisting of an input variable (eg., b_in) and an output variable (eg., b_out)."
-" These are referred to as loop-carried dependences. The loop operation node"
-" supplies the input value of the input variable for the first iteration, and"
-" returns the output value of the output variable produced by the final"
-" iteration."
-"3) Scan_output variables are used to implicitly concatenate values computed across"
-" all the iterations. In the above example, the value of user_defined_val computed"
-" over all iterations are concatenated and returned as the value of user_defined_vals"
-" after the loop."
-"4) Values created in the body cannot be accessed in the enclosing scope,"
-" except using the mechanism described above."
+"2) Any variables which you wish to make available in the enclosing scope (i.e."
+" the variables b and keepgoing) must be declared as either loop-carried"
+" dependencies (both at the op inputs and output and at the body net input and"
+" output) or scan_outputs."
+"3) Values created in the body cannot be accessed in the enclosing scope."
 ""
-"Note that the semantics of this op support "diagonal" or "wavefront" execution."
+"Note that the semantics of this op support \"diagonal\" or \"wavefront\" execution."
 "(See Step 3 here for an example:"
 "https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)."
 "Frontends should emit multi-layer RNNs as a series of While operators (with"
@@ -2280,8 +2233,8 @@ ONNX Loop operation
 
 #### Operands:
 
-1. `M`: memref of any type values or tensor of any type values
-1. `cond`: memref of any type values or tensor of any type values
+1. `M`: memref of any type values or tensor of any type values or none type
+1. `cond`: memref of any type values or tensor of any type values or none type
 1. `v_initial`: memref of any type values or tensor of any type values
 
 #### Attributes:
@@ -2360,8 +2313,8 @@ ONNX MatMulInteger operation
 
 1. `A`: memref of any type values or tensor of any type values
 1. `B`: memref of any type values or tensor of any type values
-1. `a_zero_point`: memref of any type values or tensor of any type values
-1. `b_zero_point`: memref of any type values or tensor of any type values
+1. `a_zero_point`: memref of any type values or tensor of any type values or none type
+1. `b_zero_point`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -2444,7 +2397,7 @@ ONNX MaxPool operation
 " ```"
 " pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]"
 " ```"
-" The output of each pooling window is maximum number of elements exclude pad. "
+" The output of each pooling window is maximum number of elements exclude pad."
 " "
 
 #### Operands:
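The pad_shape formula in this hunk can be checked on one spatial axis with concrete numbers; a small sketch (values chosen only for illustration):

```cpp
#include <cassert>

int main() {
  // One spatial axis: input 5, kernel 3, stride 2, dilation 1, output 3.
  int output = 3, stride = 2, kernel = 3, dilation = 1, input = 5;
  int pad = (output - 1) * stride + ((kernel - 1) * dilation + 1) - input;
  assert(pad == 2); // split between the start and end of the axis
  return 0;
}
```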
@@ -2466,7 +2419,7 @@ ONNX MaxPool operation
 #### Results:
 
 1. `Y`: memref of any type values or tensor of any type values
-1. `Indices`: memref of any type values or tensor of any type values
+1. `Indices`: memref of any type values or tensor of any type values or none type
 
 ### onnx.MaxPoolSingleOut (ONNXMaxPoolSingleOutOp)
 ONNX MaxPool operation with a single output.
@@ -2552,7 +2505,7 @@ ONNX MaxUnpool operation
 
 1. `X`: memref of any type values or tensor of any type values
 1. `I`: memref of any type values or tensor of any type values
-1. `output_shape`: memref of any type values or tensor of any type values
+1. `output_shape`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -2752,9 +2705,9 @@ ONNX NonMaxSuppression operation
 
 1. `boxes`: memref of any type values or tensor of any type values
 1. `scores`: memref of any type values or tensor of any type values
-1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values
-1. `iou_threshold`: memref of any type values or tensor of any type values
-1. `score_threshold`: memref of any type values or tensor of any type values
+1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values or none type
+1. `iou_threshold`: memref of any type values or tensor of any type values or none type
+1. `score_threshold`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -3067,7 +3020,7 @@ ONNX Pad operation
 
 1. `data`: memref of any type values or tensor of any type values
 1. `pads`: memref of any type values or tensor of any type values
-1. `constant_value`: memref of any type values or tensor of any type values
+1. `constant_value`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -3124,7 +3077,7 @@ ONNX QLinearConv operation
 1. `w_zero_point`: memref of any type values or tensor of any type values
 1. `y_scale`: memref of any type values or tensor of any type values
 1. `y_zero_point`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -3188,7 +3141,7 @@ ONNX QuantizeLinear operation
 
 1. `x`: memref of any type values or tensor of any type values
 1. `y_scale`: memref of any type values or tensor of any type values
-1. `y_zero_point`: memref of any type values or tensor of any type values
+1. `y_zero_point`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -3270,9 +3223,9 @@ ONNX RNN operation
 1. `X`: memref of any type values or tensor of any type values
 1. `W`: memref of any type values or tensor of any type values
 1. `R`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-1. `sequence_lens`: memref of any type values or tensor of any type values
-1. `initial_h`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
+1. `sequence_lens`: memref of any type values or tensor of any type values or none type
+1. `initial_h`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -3287,8 +3240,8 @@ ONNX RNN operation
 
 #### Results:
 
-1. `Y`: memref of any type values or tensor of any type values
-1. `Y_h`: memref of any type values or tensor of any type values
+1. `Y`: memref of any type values or tensor of any type values or none type
+1. `Y_h`: memref of any type values or tensor of any type values or none type
 
 ### onnx.RandomNormalLike (ONNXRandomNormalLikeOp)
 ONNX RandomNormalLike operation
@@ -3813,14 +3766,14 @@ ONNX Resize operation
 
 "Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor."
 "Each dimension value of the output tensor is:"
-" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified."
+" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified."
 
 #### Operands:
 
 1. `X`: memref of any type values or tensor of any type values
 1. `roi`: memref of any type values or tensor of any type values
 1. `scales`: memref of any type values or tensor of any type values
-1. `sizes`: memref of any type values or tensor of any type values
+1. `sizes`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
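The output_dimension formula above, evaluated on sample numbers (illustrative only, not part of the diff):

```cpp
#include <cassert>
#include <cmath>

int main() {
  // input_dimension = 4, roi = [0.0, 1.0] on this axis, scale = 0.6:
  // floor(4 * (1.0 - 0.0) * 0.6) = floor(2.4) = 2.
  int input_dimension = 4;
  float roi_start = 0.0f, roi_end = 1.0f, scale = 0.6f;
  int out = (int)std::floor(input_dimension * (roi_end - roi_start) * scale);
  assert(out == 2);
  return 0;
}
```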
@@ -4438,7 +4391,7 @@ ONNX SequenceErase operation
 #### Operands:
 
 1. `input_sequence`: memref of any type values or tensor of any type values
-1. `position`: memref of any type values or tensor of any type values
+1. `position`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -4463,7 +4416,7 @@ ONNX SequenceInsert operation
 
 1. `input_sequence`: memref of any type values or tensor of any type values
 1. `tensor`: memref of any type values or tensor of any type values
-1. `position`: memref of any type values or tensor of any type values
+1. `position`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -4680,8 +4633,8 @@ ONNX Slice operation
 1. `data`: memref of any type values or tensor of any type values
 1. `starts`: memref of any type values or tensor of any type values
 1. `ends`: memref of any type values or tensor of any type values
-1. `axes`: memref of any type values or tensor of any type values
-1. `steps`: memref of any type values or tensor of any type values
+1. `axes`: memref of any type values or tensor of any type values or none type
+1. `steps`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -4834,7 +4787,7 @@ ONNX SplitToSequence operation
 #### Operands:
 
 1. `input`: memref of any type values or tensor of any type values
-1. `split`: memref of any type values or tensor of any type values
+1. `split`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -4902,9 +4855,9 @@ ONNX StringNormalizer operation
 "StringNormalization performs string operations for basic cleaning."
 "This operator has only one input (denoted by X) and only one output"
 "(denoted by Y). This operator first examines the elements in the X,"
-"and removes elements specified in "stopwords" attribute."
+"and removes elements specified in \"stopwords\" attribute."
 "After removing stop words, the intermediate result can be further lowercased,"
-"uppercased, or just returned depending the "case_change_action" attribute."
+"uppercased, or just returned depending the \"case_change_action\" attribute."
 "This operator only accepts [C]- and [1, C]-tensor."
 "If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]"
 "if input shape is [C] and shape [1, 1] if input shape is [1, C]."
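A rough model of the two steps described above (stopword removal, then the case change), not the operator's implementation; the sample data is invented:

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>
#include <vector>

int main() {
  // X = ["monday", "and", "tuesday"], stopwords = ["and"],
  // case_change_action = "UPPER".
  std::vector<std::string> x = {"monday", "and", "tuesday"}, y;
  for (const auto &s : x)
    if (s != "and") // drop stopwords first
      y.push_back(s);
  for (auto &s : y) // then apply the case change
    std::transform(s.begin(), s.end(), s.begin(), ::toupper);
  assert(y.size() == 2 && y[0] == "MONDAY" && y[1] == "TUESDAY");
  return 0;
}
```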
@@ -5034,8 +4987,8 @@ ONNX TfIdfVectorizer operation
 "respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output."
 "Note that we may consider all skips up to S when generating the n-grams."
 ""
-"The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and"
-"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF","
+"The examples used above are true if mode is \"TF\". If mode is \"IDF\", all the counts larger than 1 would be truncated to 1 and"
+"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is \"TFIDF\","
 "this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute."
 ""
 "Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor."
@@ -5123,9 +5076,9 @@ ONNX TopK operation
 " contains the indices of the top k elements (original indices from the input"
 " tensor)."
 ""
-"If "largest" is 1 (the default value) then the k largest elements are returned."
-"If "sorted" is 1 (the default value) then the resulting k elements will be sorted."
-"If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined."
+"If \"largest\" is 1 (the default value) then the k largest elements are returned."
+"If \"sorted\" is 1 (the default value) then the resulting k elements will be sorted."
+"If \"sorted\" is 0, order of returned 'Values' and 'Indices' are undefined."
 ""
 "Given two equivalent values, this operator uses the indices along the axis as"
 " a tiebreaker. That is, the element with the lower index will appear first."
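The tie-breaking rule quoted in this hunk (equal values, lower index first) is exactly what a stable sort over indices gives; an illustrative sketch with invented data:

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

int main() {
  // X = [3, 1, 3, 2], k = 2, largest = 1, sorted = 1.
  std::vector<float> x = {3, 1, 3, 2};
  std::vector<int> idx(x.size());
  std::iota(idx.begin(), idx.end(), 0);
  // Sort by value descending; stability makes the lower index win ties.
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int a, int b) { return x[a] > x[b]; });
  // Values = [3, 3], Indices = [0, 2].
  assert(idx[0] == 0 && idx[1] == 2);
  return 0;
}
```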
@@ -5184,7 +5137,7 @@ ONNX Unique operation
 "This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. "
 "The first output tensor 'Y' contains all unique values or subtensors of the input. "
 "The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. "
-"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". "
+"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. \". "
 "The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. "
 ""
 "Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. "
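A concrete reading of the four outputs described above (sorted case; the sample data is invented for illustration):

```cpp
#include <cassert>
#include <vector>

int main() {
  // X = [2, 1, 1, 3, 4, 3] gives, with sorted outputs:
  //   Y               = [1, 2, 3, 4]
  //   indices         = [1, 0, 3, 4]        first occurrence of Y[i] in X
  //   inverse_indices = [1, 0, 0, 2, 3, 2]  position of X[i] in Y
  //   counts          = [2, 1, 2, 1]
  std::vector<int> x = {2, 1, 1, 3, 4, 3};
  std::vector<int> y = {1, 2, 3, 4};
  std::vector<int> inverse_indices = {1, 0, 0, 2, 3, 2};
  for (size_t i = 0; i < x.size(); ++i)
    assert(y[inverse_indices[i]] == x[i]); // inverse_indices rebuilds X
  return 0;
}
```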
@@ -5268,9 +5221,9 @@ ONNX Unique operation
 #### Results:
 
 1. `Y`: memref of any type values or tensor of any type values
-1. `indices`: memref of any type values or tensor of any type values
-1. `inverse_indices`: memref of any type values or tensor of any type values
-1. `counts`: memref of any type values or tensor of any type values
+1. `indices`: memref of any type values or tensor of any type values or none type
+1. `inverse_indices`: memref of any type values or tensor of any type values or none type
+1. `counts`: memref of any type values or tensor of any type values or none type
 
 ### onnx.Unsqueeze (ONNXUnsqueezeOp)
 ONNX Unsqueeze operation

doc/gen_doc.py: 1041 changed lines (diff suppressed because it is too large).
@@ -62,7 +62,21 @@ target_include_directories(onnf_shape_inference
 target_link_libraries(onnf_shape_inference ${MLIRLibs})
 add_dependencies(onnf_shape_inference gen_krnl_ops)
 
-add_library(onnf_lower_frontend conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
+add_library(onnf_lower_frontend
+        conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
+        conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
+        conversion/onnx_to_krnl/math/elementwise.cpp
+        conversion/onnx_to_krnl/math/gemm.cpp
+        conversion/onnx_to_krnl/math/matmul.cpp
+        conversion/onnx_to_krnl/math/reduction.cpp
+        conversion/onnx_to_krnl/math/softmax.cpp
+        conversion/onnx_to_krnl/nn/conv.cpp
+        conversion/onnx_to_krnl/nn/normalization.cpp
+        conversion/onnx_to_krnl/tensor/identity.cpp
+        conversion/onnx_to_krnl/tensor/reshape.cpp
+        conversion/onnx_to_krnl/tensor/transpose.cpp
+        conversion/onnx_to_krnl/tensor/unsqueeze.cpp
+        conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
 target_include_directories(onnf_lower_frontend
         PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
         ${ONNF_SRC_ROOT})
@@ -121,6 +121,7 @@ private:
   mlir::MLIRContext &context_;
   mlir::ModuleOp module_;
   mlir::OpBuilder builder_;
+  mlir::Value none_;
   // mapping between string name and symbol
   OnnxOnnfSymbolMapping frontend_symbols_;
 
@@ -188,8 +189,9 @@ private:
       }
     }
 
-    mlir::Type elementType =
-        convertONNXTypeToMLIRType(input.type().tensor_type().elem_type());
+    auto elementOnnxType =
+        (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
+    mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
     llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
     arg_types.emplace_back(
         mlir::RankedTensorType::get(tensor_dims, elementType));
@@ -287,8 +289,8 @@ private:
     }
   }
 
-  std::vector<mlir::NamedAttribute> ImportNodeAttributes(
-      const onnx::NodeProto &node) {
+  std::vector<mlir::NamedAttribute>
+  ImportNodeAttributes(const onnx::NodeProto &node) {
     std::vector<mlir::NamedAttribute> attributes;
     for (int i = 0; i < node.attribute_size(); ++i) {
       auto attr = node.attribute(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
|
|
||||||
// combined with 'if constexpr' the issue is the type of the output is
|
|
||||||
// different. alternative way to use variadic output for all the op
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* Important onnx node which generates only one output
|
|
||||||
* @param node onnx node
|
|
||||||
* @param nIn number of expected inputs
|
|
||||||
* @param nOut number of expected outputs
|
|
||||||
* @param attrs list of desription for attributes with format {name, type,
|
|
||||||
* default}
|
|
||||||
*/
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
|
void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
|
||||||
bool variadicIn = false, bool variadicOut = false) {
|
int expectedNumResults = -1) {
|
||||||
|
bool variadicIn = expectedNumOperands == -1;
|
||||||
|
bool variadicOut = expectedNumResults == -1;
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (const auto &item : node.input()) {
|
for (const auto &item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
|
@@ -339,6 +331,10 @@ private:
       }
     }
 
+    if (!variadicIn)
+      for (auto i = inputs.size(); i < expectedNumOperands; i++)
+        inputs.emplace_back(none_);
+
     std::vector<mlir::Type> outputTypes;
     for (auto item : node.output()) {
       outputTypes.push_back(
@@ -347,49 +343,11 @@ private:
 
     auto attributes = ImportNodeAttributes(node);
 
-    llvm::StringRef OpName = node.op_type();
-    if ((variadicIn || nIn == inputs.size()) &&
-        (variadicOut || nOut == outputTypes.size())) {
-      auto op =
-          builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
-      frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
-                                   op.getResult());
-    } else {
-      ImportNodeGeneric(node);
-    }
-  }
-
-  template <typename T>
-  void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut,
-                              bool variadicIn = false,
-                              bool variadicOut = false) {
-    std::vector<mlir::Value> inputs;
-    for (const auto &item : node.input()) {
-      if (frontend_symbols_.ContainKey(legalize_name(item))) {
-        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
-      }
-    }
-
-    std::vector<mlir::Type> outputTypes;
-    for (auto item : node.output()) {
-      outputTypes.push_back(
-          mlir::UnrankedTensorType::get(builder_.getF32Type()));
-    }
-
-    auto attributes = ImportNodeAttributes(node);
-
-    llvm::StringRef OpName = node.op_type();
-
-    if ((variadicIn || nIn == inputs.size()) &&
-        (variadicOut || nOut == outputTypes.size())) {
-      auto op =
-          builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
-      for (int i = 0; i < node.output().size(); i++) {
-        frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
-                                     op.getResult(i));
-      }
-    } else {
-      ImportNodeGeneric(node);
-    }
+    // TODO: Handle optional inputs.
+    auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
+    for (int i = 0; i < node.output().size(); i++) {
+      frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
+                                   *(op.getODSResults(i).begin()));
+    }
   }
 
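The hunks above replace the ImportNodeOneOut/ImportNodeMultipleOuts pair with a single buildOperation that pads omitted trailing optional inputs with the none_ constant. A standalone sketch of just that padding rule (strings stand in for mlir::Value here; this is not the importer code itself):

```cpp
#include <cassert>
#include <string>
#include <vector>

// When a fixed operand count is expected, trailing optional inputs the
// ONNX node omits are filled with a placeholder; the importer uses its
// NoneType constant none_, modeled here as the string "<none>".
std::vector<std::string> padInputs(std::vector<std::string> inputs,
                                   int expectedNumOperands) {
  bool variadicIn = expectedNumOperands == -1;
  if (!variadicIn)
    for (auto i = inputs.size(); i < (size_t)expectedNumOperands; i++)
      inputs.emplace_back("<none>");
  return inputs;
}

int main() {
  // A Conv node given only X and W still yields three operands; the
  // optional bias B becomes the none placeholder.
  auto ops = padInputs({"X", "W"}, 3);
  assert(ops.size() == 3 && ops[2] == "<none>");
  return 0;
}
```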
@@ -398,8 +356,7 @@ private:
    * c++ does not allow template specialization inside a class scope
    * a specialized function is used
    */
-  void
-  ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
+  void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
     // Conv has attribute dilations, kernel_shape, pads, the default value of
     // which is determined by the shape of first argument. However, since the
     // shape is unknown now, these attributes can be not generated auto
@@ -413,24 +370,20 @@ private:
     int nOps = node.input().size();
 
     if (nOps == 2)
-      ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(
-          node, nOps, nOut);
+      buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
     else
-      ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut);
+      buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
   }
 
   /*!
    * Special handle for MaxPool operations.
    */
-  void ImportNodeMaxPool(
-      onnx::NodeProto node, int nIn, int nOut) {
+  void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) {
     int nOuts = node.output().size();
     if (nOuts == 1) {
-      ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(
-          node, nIn, nOuts);
+      buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts);
     } else {
-      ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(
-          node, nIn, nOuts);
+      buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts);
     }
   }
 
|
||||||
int nOuts = node.output().size();
|
int nOuts = node.output().size();
|
||||||
if (nOuts == 1) {
|
if (nOuts == 1) {
|
||||||
// Test mode with one output.
|
// Test mode with one output.
|
||||||
ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn,
|
buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts);
|
||||||
nOuts);
|
|
||||||
} else {
|
} else {
|
||||||
// Training mode with four trailing optional outputs. Not handled yet.
|
// Training mode with four trailing optional outputs. Not handled yet.
|
||||||
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
|
buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* Special handle for Gemm operations.
|
|
||||||
*/
|
|
||||||
void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) {
|
|
||||||
int nOps = node.input().size();
|
|
||||||
if (nOps == 2) {
|
|
||||||
ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut);
|
|
||||||
} else {
|
|
||||||
ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -467,28 +407,14 @@ private:
  void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
    int nOps = node.input().size();
    if (nOps == 2) {
      ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
      buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
    } else {
      ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut);
      buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
    }
  }

  void ImportNode(const onnx::NodeProto &node) {
    std::vector<mlir::Value> inputs;
    for (const auto &item : node.input()) {
      if (frontend_symbols_.ContainKey(legalize_name(item))) {
        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
      }
    }

    std::vector<mlir::Type> outputTypes;
    for (auto item : node.output()) {
      outputTypes.push_back(
          mlir::UnrankedTensorType::get(builder_.getF32Type()));
    }

    std::vector<mlir::NamedAttribute> attributes;
    llvm::StringRef OpName = node.op_type();
    llvm::StringRef opName = node.op_type();

    // the following code is generated by gen_doc.py
    // refer to dialect/onnx/onnx.td for details
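The deleted ImportNode body shows the importer's symbol-table discipline: each ONNX input name is legalized and looked up in frontend_symbols_ before being wired into the new op. A standalone sketch of that lookup, with std::unordered_map and std::string standing in for the real symbol table and mlir::Value:

// Standalone sketch of the frontend symbol-table lookup in ImportNode.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

using Value = std::string; // stand-in for mlir::Value

struct SymbolTable {
  std::unordered_map<std::string, Value> env;
  bool ContainKey(const std::string &name) const { return env.count(name); }
  Value GetTensorByOnnxName(const std::string &name) const {
    return env.at(name);
  }
};

int main() {
  SymbolTable frontend_symbols{{{"X", "%arg0"}, {"W", "%arg1"}}};
  std::vector<Value> inputs;
  for (const std::string &item : {std::string("X"), std::string("W")})
    if (frontend_symbols.ContainKey(item))
      inputs.push_back(frontend_symbols.GetTensorByOnnxName(item));
  for (const Value &v : inputs)
    std::cout << v << "\n"; // prints %arg0 then %arg1
}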
@@ -555,9 +481,11 @@ private:
      ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
    }

    // import nodes in the graph
    auto node = graph.node();
    for (const auto &item : node) {
    // Create a NoneTyped constant.
    none_ =
        builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());

    // Import nodes in the graph.
    for (const auto &item : graph.node()) {
      ImportNode(item);
    }
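The new none_ value is a unit-attribute constant; presumably it serves as the placeholder wired into absent optional ONNX operands (an assumption, since this diff does not show its uses). A standalone sketch of that idea, with std::optional modeling a possibly-missing operand:

// Sketch of a "none" placeholder for optional operands, assuming none_
// plays this role in the importer (not shown in this diff).
#include <iostream>
#include <optional>
#include <string>

using Value = std::string;
const Value none_ = "<none>"; // stands in for the unit-attribute constant

Value operandOrNone(const std::optional<Value> &v) {
  return v ? *v : none_; // absent optional inputs are wired to none_
}

int main() {
  std::cout << operandOrNone(std::optional<Value>("%bias")) << "\n"; // %bias
  std::cout << operandOrNone(std::nullopt) << "\n";                  // <none>
}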
@@ -1,320 +1,319 @@
//********************************************************
// Warning: Do not modify this file directly
// This file is automatically generated via script
// Details can be found in doc/readonnxdefs.md
//********************************************************
//********************************************************
// This file is generated on UTC-02/24/2020, 06:29:01.
// Do not modify this file directly.
// This file is automatically generated via script.
// Details can be found in doc/readonnxdefs.md .
//********************************************************

    if (OpName == "DUMMY") {
    }else if (OpName == "Abs") {
      ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1);
    if (opName == "Abs")
      return buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Acos") {
      ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1);
    if (opName == "Acos")
      return buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Acosh") {
      ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1);
    if (opName == "Acosh")
      return buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Add") {
      ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1);
    if (opName == "Add")
      return buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "And") {
      ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1);
    if (opName == "And")
      return buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "ArgMax") {
      ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1);
    if (opName == "ArgMax")
      return buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ArgMin") {
      ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1);
    if (opName == "ArgMin")
      return buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Asin") {
      ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1);
    if (opName == "Asin")
      return buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Asinh") {
      ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1);
    if (opName == "Asinh")
      return buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Atan") {
      ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1);
    if (opName == "Atan")
      return buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Atanh") {
      ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1);
    if (opName == "Atanh")
      return buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "AveragePool") {
      ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1);
    if (opName == "AveragePool")
      return buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "BatchNormalization") {
      ImportNodeBatchNormalization(node, 5, 5);
    if (opName == "BatchNormalization")
      return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
    }else if (OpName == "BitShift") {
      ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1);
    if (opName == "BitShift")
      return buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Cast") {
      ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1);
    if (opName == "Cast")
      return buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Ceil") {
      ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1);
    if (opName == "Ceil")
      return buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Clip") {
      ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1);
    if (opName == "Clip")
      return buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "Compress") {
      ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1);
    if (opName == "Compress")
      return buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Concat") {
      ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, true, false);
    if (opName == "Concat")
      return buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
    }else if (OpName == "ConcatFromSequence") {
      ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1);
    if (opName == "ConcatFromSequence")
      return buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Constant") {
      ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1);
    if (opName == "Constant")
      return buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
    }else if (OpName == "ConstantOfShape") {
      ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1);
    if (opName == "ConstantOfShape")
      return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Conv") {
      ImportNodeConv(node, 3, 1);
    if (opName == "Conv")
      return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "ConvInteger") {
      ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1);
    if (opName == "ConvInteger")
      return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
    }else if (OpName == "ConvTranspose") {
      ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1);
    if (opName == "ConvTranspose")
      return buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "Cos") {
      ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1);
    if (opName == "Cos")
      return buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Cosh") {
      ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1);
    if (opName == "Cosh")
      return buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "CumSum") {
      ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1);
    if (opName == "CumSum")
      return buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "DepthToSpace") {
      ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1);
    if (opName == "DepthToSpace")
      return buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "DequantizeLinear") {
      ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1);
    if (opName == "DequantizeLinear")
      return buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "Det") {
      ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1);
    if (opName == "Det")
      return buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Div") {
      ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1);
    if (opName == "Div")
      return buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Dropout") {
      ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2);
    if (opName == "Dropout")
      return buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
    }else if (OpName == "DynamicQuantizeLinear") {
      ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3);
    if (opName == "DynamicQuantizeLinear")
      return buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
    }else if (OpName == "Elu") {
      ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1);
    if (opName == "Elu")
      return buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Equal") {
      ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1);
    if (opName == "Equal")
      return buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Erf") {
      ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1);
    if (opName == "Erf")
      return buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Exp") {
      ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1);
    if (opName == "Exp")
      return buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Expand") {
      ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1);
    if (opName == "Expand")
      return buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "EyeLike") {
      ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1);
    if (opName == "EyeLike")
      return buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Flatten") {
      ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1);
    if (opName == "Flatten")
      return buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Floor") {
      ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1);
    if (opName == "Floor")
      return buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "GRU") {
      ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2);
    if (opName == "GRU")
      return buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
    }else if (OpName == "Gather") {
      ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1);
    if (opName == "Gather")
      return buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "GatherElements") {
      ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1);
    if (opName == "GatherElements")
      return buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "GatherND") {
      ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1);
    if (opName == "GatherND")
      return buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Gemm") {
      ImportNodeGemm(node, 3, 1);
    if (opName == "Gemm")
      return buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "GlobalAveragePool") {
      ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1);
    if (opName == "GlobalAveragePool")
      return buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "GlobalLpPool") {
      ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1);
    if (opName == "GlobalLpPool")
      return buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "GlobalMaxPool") {
      ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1);
    if (opName == "GlobalMaxPool")
      return buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Greater") {
      ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1);
    if (opName == "Greater")
      return buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "HardSigmoid") {
      ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1);
    if (opName == "HardSigmoid")
      return buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Hardmax") {
      ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1);
    if (opName == "Hardmax")
      return buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Identity") {
      ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1);
    if (opName == "Identity")
      return buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "If") {
      ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1);
    if (opName == "If")
      return buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
    }else if (OpName == "InstanceNormalization") {
      ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1);
    if (opName == "InstanceNormalization")
      return buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "IsInf") {
      ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1);
    if (opName == "IsInf")
      return buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "IsNaN") {
      ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1);
    if (opName == "IsNaN")
      return buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "LRN") {
      ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1);
    if (opName == "LRN")
      return buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "LSTM") {
      ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3);
    if (opName == "LSTM")
      return buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
    }else if (OpName == "LeakyRelu") {
      ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1);
    if (opName == "LeakyRelu")
      return buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Less") {
      ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1);
    if (opName == "Less")
      return buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Log") {
      ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1);
    if (opName == "Log")
      return buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "LogSoftmax") {
      ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1);
    if (opName == "LogSoftmax")
      return buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Loop") {
      ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1);
    if (opName == "Loop")
      return buildOperation<mlir::ONNXLoopOp>(node);
    }else if (OpName == "LpNormalization") {
      ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1);
    if (opName == "LpNormalization")
      return buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "LpPool") {
      ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1);
    if (opName == "LpPool")
      return buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "MatMul") {
      ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1);
    if (opName == "MatMul")
      return buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "MatMulInteger") {
      ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1);
    if (opName == "MatMulInteger")
      return buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
    }else if (OpName == "Max") {
      ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, true, false);
    if (opName == "Max")
      return buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
    }else if (OpName == "MaxPool") {
      ImportNodeMaxPool(node, 1, 2);
    if (opName == "MaxPool")
      return ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
    }else if (OpName == "MaxRoiPool") {
      ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1);
    if (opName == "MaxRoiPool")
      return buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "MaxUnpool") {
      ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1);
    if (opName == "MaxUnpool")
      return buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "Mean") {
      ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, true, false);
    if (opName == "Mean")
      return buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
    }else if (OpName == "MeanVarianceNormalization") {
      ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1);
    if (opName == "MeanVarianceNormalization")
      return buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Min") {
      ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, true, false);
    if (opName == "Min")
      return buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
    }else if (OpName == "Mod") {
      ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1);
    if (opName == "Mod")
      return buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Mul") {
      ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1);
    if (opName == "Mul")
      return buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Multinomial") {
      ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1);
    if (opName == "Multinomial")
      return buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Neg") {
      ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1);
    if (opName == "Neg")
      return buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "NonMaxSuppression") {
      ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1);
    if (opName == "NonMaxSuppression")
      return buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
    }else if (OpName == "NonZero") {
      ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1);
    if (opName == "NonZero")
      return buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Not") {
      ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1);
    if (opName == "Not")
      return buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "OneHot") {
      ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1);
    if (opName == "OneHot")
      return buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "Or") {
      ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1);
    if (opName == "Or")
      return buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "PRelu") {
      ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1);
    if (opName == "PRelu")
      return buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Pad") {
      ImportNodePad(node, 3, 1);
    if (opName == "Pad")
      return ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "Pow") {
      ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1);
    if (opName == "Pow")
      return buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "QLinearConv") {
      ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1);
    if (opName == "QLinearConv")
      return buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
    }else if (OpName == "QLinearMatMul") {
      ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1);
    if (opName == "QLinearMatMul")
      return buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
    }else if (OpName == "QuantizeLinear") {
      ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1);
    if (opName == "QuantizeLinear")
      return buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "RNN") {
      ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2);
    if (opName == "RNN")
      return buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
    }else if (OpName == "RandomNormal") {
      ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1);
    if (opName == "RandomNormal")
      return buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
    }else if (OpName == "RandomNormalLike") {
      ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1);
    if (opName == "RandomNormalLike")
      return buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "RandomUniform") {
      ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1);
    if (opName == "RandomUniform")
      return buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
    }else if (OpName == "RandomUniformLike") {
      ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1);
    if (opName == "RandomUniformLike")
      return buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Range") {
      ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1);
    if (opName == "Range")
      return buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "Reciprocal") {
      ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1);
    if (opName == "Reciprocal")
      return buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ReduceL1") {
      ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1);
    if (opName == "ReduceL1")
      return buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ReduceL2") {
      ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1);
    if (opName == "ReduceL2")
      return buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ReduceLogSum") {
      ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1);
    if (opName == "ReduceLogSum")
      return buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ReduceLogSumExp") {
      ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1);
    if (opName == "ReduceLogSumExp")
      return buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ReduceMax") {
      ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1);
    if (opName == "ReduceMax")
      return buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ReduceMean") {
      ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1);
    if (opName == "ReduceMean")
      return buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ReduceMin") {
      ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1);
    if (opName == "ReduceMin")
      return buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ReduceProd") {
      ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1);
    if (opName == "ReduceProd")
      return buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ReduceSum") {
      ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1);
    if (opName == "ReduceSum")
      return buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ReduceSumSquare") {
      ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1);
    if (opName == "ReduceSumSquare")
      return buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Relu") {
      ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1);
    if (opName == "Relu")
      return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Reshape") {
      ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1);
    if (opName == "Reshape")
      return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Resize") {
      ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1);
    if (opName == "Resize")
      return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
    }else if (OpName == "ReverseSequence") {
      ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1);
    if (opName == "ReverseSequence")
      return buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "RoiAlign") {
      ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1);
    if (opName == "RoiAlign")
      return buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "Round") {
      ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1);
    if (opName == "Round")
      return buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Scan") {
      ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1);
    if (opName == "Scan")
      return buildOperation<mlir::ONNXScanOp>(node);
    }else if (OpName == "Scatter") {
      ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1);
    if (opName == "Scatter")
      return buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "ScatterElements") {
      ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1);
    if (opName == "ScatterElements")
      return buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "ScatterND") {
      ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1);
    if (opName == "ScatterND")
      return buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "Selu") {
      ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1);
    if (opName == "Selu")
      return buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "SequenceAt") {
      ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1);
    if (opName == "SequenceAt")
      return buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "SequenceConstruct") {
      ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, true, false);
    if (opName == "SequenceConstruct")
      return buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
    }else if (OpName == "SequenceEmpty") {
      ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1);
    if (opName == "SequenceEmpty")
      return buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
    }else if (OpName == "SequenceErase") {
      ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1);
    if (opName == "SequenceErase")
      return buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "SequenceInsert") {
      ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1);
    if (opName == "SequenceInsert")
      return buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "SequenceLength") {
      ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1);
    if (opName == "SequenceLength")
      return buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Shape") {
      ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1);
    if (opName == "Shape")
      return buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Shrink") {
      ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1);
    if (opName == "Shrink")
      return buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Sigmoid") {
      ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1);
    if (opName == "Sigmoid")
      return buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Sign") {
      ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1);
    if (opName == "Sign")
      return buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Sin") {
      ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1);
    if (opName == "Sin")
      return buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Sinh") {
      ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1);
    if (opName == "Sinh")
      return buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Size") {
      ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1);
    if (opName == "Size")
      return buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Slice") {
      ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1);
    if (opName == "Slice")
      return buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
    }else if (OpName == "Softmax") {
      ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1);
    if (opName == "Softmax")
      return buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Softplus") {
      ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1);
    if (opName == "Softplus")
      return buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Softsign") {
      ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1);
    if (opName == "Softsign")
      return buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "SpaceToDepth") {
      ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1);
    if (opName == "SpaceToDepth")
      return buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Split") {
      ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1);
    if (opName == "Split")
      return buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
    }else if (OpName == "SplitToSequence") {
      ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1);
    if (opName == "SplitToSequence")
      return buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Sqrt") {
      ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1);
    if (opName == "Sqrt")
      return buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Squeeze") {
      ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1);
    if (opName == "Squeeze")
      return buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "StringNormalizer") {
      ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1);
    if (opName == "StringNormalizer")
      return buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Sub") {
      ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1);
    if (opName == "Sub")
      return buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Sum") {
      ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, true, false);
    if (opName == "Sum")
      return buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
    }else if (OpName == "Tan") {
      ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1);
    if (opName == "Tan")
      return buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Tanh") {
      ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1);
    if (opName == "Tanh")
      return buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "TfIdfVectorizer") {
      ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1);
    if (opName == "TfIdfVectorizer")
      return buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "ThresholdedRelu") {
      ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1);
    if (opName == "ThresholdedRelu")
      return buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Tile") {
      ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1);
    if (opName == "Tile")
      return buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "TopK") {
      ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2);
    if (opName == "TopK")
      return buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
    }else if (OpName == "Transpose") {
      ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1);
    if (opName == "Transpose")
      return buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Unique") {
      ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4);
    if (opName == "Unique")
      return buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
    }else if (OpName == "Unsqueeze") {
      ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1);
    if (opName == "Unsqueeze")
      return buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
    }else if (OpName == "Upsample") {
      ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1);
    if (opName == "Upsample")
      return buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }else if (OpName == "Where") {
      ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1);
    if (opName == "Where")
      return buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
    }else if (OpName == "Xor") {
      ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1);
    if (opName == "Xor")
      return buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
    }
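In the regenerated table, buildOperation takes explicit expected operand and result counts, and -1 evidently marks a variadic count (Concat, Max, Mean, Min, and Sum take -1 operands; If and Split produce -1 results). A small standalone sketch of the arity check this implies, under the assumption that -1 means "any number"; the real buildOperation also constructs the MLIR operation:

// Sketch of the expected-count convention in the generated dispatch table.
#include <cassert>
#include <cstdio>

bool countsMatch(int actual, int expected) {
  return expected == -1 || actual == expected; // -1 = variadic, accept any
}

int main() {
  assert(countsMatch(4, -1)); // Concat: any number of inputs
  assert(countsMatch(2, 2));  // Add: exactly two inputs
  assert(!countsMatch(1, 2)); // wrong arity is rejected
  std::puts("ok");
}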
@@ -8,404 +8,11 @@
// Krnl IR and standard operations.
//
//===----------------------------------------------------------------------===//
#include <map>

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"

#include "src/dialect/krnl/krnl_helper.hpp"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;
//===----------------------------------------------------------------------===//
// FrontendToAffine RewritePatterns
//===----------------------------------------------------------------------===//

/// Check if all dimensions are known at compile time.
static bool hasAllConstantDimensions(MemRefType type) {
  auto memRefShape = type.getShape();
  for (int i = 0; i < memRefShape.size(); ++i)
    if (memRefShape[i] < 0)
      return false;
  return true;
}
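A plain-C++ restatement of the helper above, with a std::vector standing in for the MemRef shape; MLIR encodes dynamic extents as negative values:

// Standalone restatement of the dynamic-dimension check.
#include <cassert>
#include <vector>

bool hasAllConstantDimensions(const std::vector<long> &shape) {
  for (long d : shape)
    if (d < 0)
      return false; // a negative extent marks a dynamic dimension
  return true;
}

int main() {
  assert(hasAllConstantDimensions({2, 3, 4}));
  assert(!hasAllConstantDimensions({2, -1, 4})); // -1 = unknown at compile time
}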
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
static MemRefType convertToMemRefType(Type type) {
  MemRefType memRefType;
  auto tensorType = type.dyn_cast<TensorType>();
  if (tensorType) {
    assert(tensorType.hasRank() && "expected only ranked shapes");
    memRefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  } else {
    memRefType = type.dyn_cast<MemRefType>();
  }
  return memRefType;
}
/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
                                   PatternRewriter &rewriter,
                                   bool insertDealloc,
                                   ArrayRef<Value> operands = {}) {
  // Put together alloc operands for any dynamic dimensions of the memref.
  AllocOp alloc;
  if (!operands.empty()) {
    auto memRefShape = type.getShape();
    auto rank = memRefShape.size();

    std::map<int, Value> fromOperands;
    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
      int memRefDimIdx = rank - 1 - reversedIdx;
      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
        Value maxDim = nullptr;
        for (int i = 0; i < operands.size(); i++) {
          auto operandShape =
              operands[i].getType().cast<MemRefType>().getShape();
          int operandDimIdx = operandShape.size() - 1 - reversedIdx;

          if (operandDimIdx < 0)
            continue;

          // In case of operations with broadcasting, the dimension of the
          // alloc result is the maximum size along each dimension of the
          // operands.
          auto operandDim =
              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
          if (maxDim) {
            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
                                                        operandDim, maxDim);
            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
                                               maxDim);
          } else {
            maxDim = operandDim;
          }
        }
        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
      }
    }

    SmallVector<Value, 4> allocOperands;
    for (int i = 0; i < rank; ++i)
      if (memRefShape[i] < 0)
        allocOperands.push_back(fromOperands[i]);
    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
  } else {
    alloc = rewriter.create<AllocOp>(loc, type);
  }

  // Make sure to allocate at the beginning of the block if
  // all dimensions are known.
  auto *parentBlock = alloc.getOperation()->getBlock();
  if (hasAllConstantDimensions(type))
    alloc.getOperation()->moveBefore(&parentBlock->front());

  if (insertDealloc) {
    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
    dealloc.getOperation()->moveBefore(&parentBlock->back());
  }

  return alloc;
}
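The dynamic-allocation path above sizes each unknown result dimension as the run-time maximum of the operands' sizes along that dimension, aligned from the rightmost axis as in NumPy-style broadcasting (hence the CmpIOp/SelectOp pair). The same shape arithmetic, restated over plain vectors for the compile-time-known case:

// Standalone sketch of the per-dimension broadcast-size computation.
#include <algorithm>
#include <cassert>
#include <vector>

long broadcastDim(const std::vector<std::vector<long>> &shapes,
                  int reversedIdx) {
  long maxDim = 1;
  for (const auto &shape : shapes) {
    int idx = (int)shape.size() - 1 - reversedIdx;
    if (idx < 0)
      continue; // operand has fewer dimensions; it broadcasts here
    maxDim = std::max(maxDim, shape[idx]);
  }
  return maxDim;
}

int main() {
  // Shapes {4, 1} and {3} broadcast to {4, 3}.
  std::vector<std::vector<long>> shapes = {{4, 1}, {3}};
  assert(broadcastDim(shapes, 0) == 3); // rightmost axis
  assert(broadcastDim(shapes, 1) == 4); // next axis to the left
}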
// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
static bool checkInsertDealloc(Operation *currentOp) {
  auto parentBlock = currentOp->getBlock();

  bool insertDealloc = true;
  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
    assert(currentOp->getNumResults() < 2 &&
           "No more than one result supported (for now).");
    // If there is at least one result to investigate.
    if (currentOp->getNumResults() > 0) {
      auto result = currentOp->getResult(0);
      for (const auto &operand : op.getOperands())
        if (operand == result)
          insertDealloc = false;
    }
  });

  return insertDealloc;
}
// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
  std::map<int64_t, int64_t> OutInDimMap;
  int64_t rank = inputTy.getRank();

  // Mark reduction axes.
  std::vector<bool> isReductionAxis;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
      isReductionAxis.push_back(true);
    else
      isReductionAxis.push_back(false);
  }

  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
    // If it is a reduction axis, there is no relationship among dimensions.
    if (isReductionAxis[inIndex]) {
      if (keepdims)
        outIndex++;
    } else {
      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
      outIndex++;
    }
  }

  return OutInDimMap;
}
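A standalone restatement of getReductionMapping with plain containers. For a rank-3 input reduced over axis 1: with keepdims the reduced axis survives as size 1 and gets no mapping; without it, output axis 1 maps to input axis 2:

// Standalone sketch of the reduction output-to-input dimension mapping.
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

std::map<int, int> reductionMapping(int rank, const std::vector<int> &axes,
                                    bool keepdims) {
  std::map<int, int> outToIn;
  int outIndex = 0;
  for (int inIndex = 0; inIndex < rank; ++inIndex) {
    bool reduced = std::find(axes.begin(), axes.end(), inIndex) != axes.end();
    if (reduced) {
      if (keepdims)
        outIndex++; // reduced axis survives as size 1; no mapping recorded
    } else {
      outToIn[outIndex++] = inIndex; // kept axis maps straight through
    }
  }
  return outToIn;
}

int main() {
  auto m1 = reductionMapping(3, {1}, /*keepdims=*/true);
  assert(m1.at(0) == 0 && m1.at(2) == 2 && m1.count(1) == 0);
  auto m2 = reductionMapping(3, {1}, /*keepdims=*/false);
  assert(m2.at(0) == 0 && m2.at(1) == 2);
}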
// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
static void addDimensionToPack(ConversionPatternRewriter &rewriter,
                               Location loc, KrnlIterateOperandPack &pack,
                               Value operand, int index) {
  auto shape = operand.getType().cast<MemRefType>().getShape();
  if (shape[index] < 0) {
    pack.pushConstantBound(0);
    pack.pushOperandBound(
        rewriter.create<DimOp>(loc, operand, index).getResult());
  } else {
    pack.pushConstantBound(0);
    pack.pushConstantBound(shape[index]);
  }
}
// Function that defines the KRNL dialect loops and their respective
// optimized version.
static KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
                   std::vector<Value> &loops,
                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
  // Define loops.
  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
  loops.reserve(numLoops);
  for (auto result : loopsOp.getResults())
    loops.push_back(result);

  // Define optimized version of the loops.
  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
  optimizedLoops.reserve(numLoops);
  for (auto result : optimizedLoopsOp.getResults())
    optimizedLoops.push_back(result);

  return optimizedLoopsOp;
}
// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
static Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
                          std::vector<Value> &loops,
                          std::vector<Value> &optimizedLoops,
                          int64_t numLoops) {
  KrnlOptimizeLoopsOp optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
  return &optimizedLoopsOp.region().front();
}
// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
static void emitKrnlLoopsAndIterationForOperand(
    ConversionPatternRewriter &rewriter, Location loc, Value operand,
    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
    KrnlIterateOp &iterateOp) {
  // Operand shape.
  auto shape = operand.getType().cast<MemRefType>().getShape();

  // Number of loops.
  int64_t rank = shape.size();

  // Define loops and optimized loops.
  std::vector<Value> optimizedLoops;
  optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);

  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
  // Iterate over the loop nest.
  for (int i = 0; i < rank; ++i)
    addDimensionToPack(rewriter, loc, pack, operand, i);

  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
}
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}
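The byte-size computation rounds the bit width up to whole bytes, which is what llvm::divideCeil(sizeInBits, 8) does. Restated:

// Standalone restatement of the element byte-size computation.
#include <cassert>

unsigned eltSizeInBytes(unsigned sizeInBits) {
  return (sizeInBits + 7) / 8; // ceiling division by 8
}

int main() {
  assert(eltSizeInBytes(32) == 4); // f32
  assert(eltSizeInBytes(1) == 1);  // i1 still occupies a full byte
  assert(eltSizeInBytes(64) == 8); // f64 / i64
}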
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
                      MemRefType memRefType, ArrayRef<Value> operands) {
  auto memRefShape = memRefType.getShape();
  int64_t rank = memRefShape.size();
  // For unknown dimensions, we need to get dimension values at runtime in
  // order to do broadcasting.
  std::map<int, std::map<int, Value>> DimInfo;
  // For each result dimension, compute the number of sharing operands.
  // Sharing operands are operands sharing the same index (counting from the
  // rightmost to the leftmost) for a given dimension.
  std::map<int, int> sharedDimCount;
  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    int dimIdx = rank - 1 - reversedIdx;
    sharedDimCount[dimIdx] = 0;
    for (int i = 0; i < operands.size(); ++i) {
      auto shape = operands[i].getType().cast<MemRefType>().getShape();
      if (reversedIdx <= shape.size() - 1)
        sharedDimCount[dimIdx]++;
    }
  }
  // An unknown dimension can have a value of 1 or N (N > 1).
  // If its value is 1, it is a broadcasted dimension; otherwise it is a
  // non-broadcasted dimension.
  // We only care about unknown dimensions whose number of sharing operands is
  // more than one, since they are potentially broadcasted dimensions.
  for (int i = 0; i < operands.size(); ++i) {
    std::map<int, Value> broadcastedDims;
    auto shape = operands[i].getType().cast<MemRefType>().getShape();
    int size = shape.size();
    for (int j = 0; j < shape.size(); ++j) {
      if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
        auto isBroadcasted =
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
      }
    }
    DimInfo.insert(std::make_pair(i, broadcastedDims));
  }
  return DimInfo;
}
// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
                          ArrayRef<Value> loopIVs, Value operand,
                          std::map<int, Value> broadcastedDims) {
  // `operand` must have a ranked type. This should have been checked by the
  // shape inference pass.
  auto operandShape = operand.getType().cast<MemRefType>().getShape();
  auto rank = operandShape.size();
  auto loopCount = loopIVs.size();

  std::vector<Value> newLoopIVs;
  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    auto dimIdx = rank - 1 - reversedIdx;
    auto loopIdx = loopCount - 1 - reversedIdx;
    if (operandShape[dimIdx] == 1) {
      // Broadcasted dimension
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      newLoopIVs.insert(newLoopIVs.begin(), zero);
    } else if ((operandShape[dimIdx] == -1) &&
               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
      // Unknown dimension: it can have a value of 1 or N (N > 1).
      // If its value is 1, it is a broadcasted dimension; otherwise it is
      // not.
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
                                           loopIVs[loopIdx]);
      newLoopIVs.insert(newLoopIVs.begin(), idx);
    } else {
      // Non-broadcasted dimension
      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
    }
  }
  return newLoopIVs;
}
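A plain-C++ sketch of the induction-variable rewrite above for the statically-known case: along a broadcasted axis (extent 1) the operand is always read at index 0, otherwise at that axis's loop induction variable. The -1 (unknown) case is the one the SelectOp handles at run time; the sketch assumes the loop nest has at least as many loops as the operand has dimensions:

// Standalone sketch of broadcast-aware access-index computation.
#include <cassert>
#include <vector>

std::vector<long> accessIndices(const std::vector<long> &operandShape,
                                const std::vector<long> &loopIVs) {
  std::vector<long> idx(operandShape.size());
  size_t rank = operandShape.size(), loops = loopIVs.size();
  for (size_t r = 0; r < rank; ++r) { // r counts from the rightmost axis
    size_t dim = rank - 1 - r, loop = loops - 1 - r;
    idx[dim] = (operandShape[dim] == 1) ? 0 : loopIVs[loop];
  }
  return idx;
}

int main() {
  // Operand of shape {4, 1} inside a {4, 3} loop nest at IVs (2, 1):
  auto idx = accessIndices({4, 1}, {2, 1});
  assert(idx[0] == 2 && idx[1] == 0); // second axis is pinned to index 0
}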
namespace {

// This is to get a scalar operation of a given type for a specific operation.
template <typename Op>
struct ScalarOp {
  using FOp = void;
  using IOp = void;
};

template <typename FOp>
using ScalarFOp = typename ScalarOp<FOp>::FOp;
template <typename IOp>
using ScalarIOp = typename ScalarOp<IOp>::IOp;

// Get the identity element of an operation.
// Return NULL if the function does not have identity.
template <typename DataType, typename Op>
DataType getIdentityValue() {
  return NULL;
}
//===----------------------------------------------------------------------===//
// This is used in the innermost loop of a KrnlIterateOp to insert computation
// composed of one or many scalar ops.
// Use template specialization for each of different ONNX operations.
//===----------------------------------------------------------------------===//
template <typename Op>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
                         ArrayRef<Value> operands,
                         ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Type element_type = operands.front().getType();
  if (element_type.isa<IntegerType>()) {
    return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
                                          mlir::None);
  } else if (element_type.isa<FloatType>()) {
    return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
                                          mlir::None);
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}
// We divide the operator lowering into different categories.
|
|
||||||
// These categories are mostly similar to the operator categories in ONNX:
|
|
||||||
// https://github.com/onnx/onnx/tree/master/onnx/defs.
|
|
||||||
// Besides, it is better to put operators with the same computation pattern into
|
|
||||||
// the same category, e.g. element-wise operators will belong to the elementwise
|
|
||||||
// category.
|
|
||||||
|
|
||||||
// Math
|
|
||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc"
|
|
||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc"
|
|
||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc"
|
|
||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc"
|
|
||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc"
|
|
||||||
// Tensor
|
|
||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc"
|
|
||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc"
|
|
||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc"
|
|
||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
|
|
||||||
// Neural network
|
|
||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
|
|
||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc"
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// EntryPoint Op lowering to Krnl Entry Point.
|
// EntryPoint Op lowering to Krnl Entry Point.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -427,39 +34,6 @@ public:
  }
};

//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//

struct TensorTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;

  TensorTypeConverter() {
    addConversion(convertType);
  }

  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
    if (auto type = convertToMemRefType(t)) {
      results.push_back(type);
      return success();
    }

    results.push_back(t);
    return success();
  }

  /// Return true if the inputs and outputs of the given function type are
  /// legal. [Taken from MLIR and adapted to only check the legality of the
  /// inputs. Once unranked results can be handled gracefully this
  /// override needs to be removed in favour of the original MLIR one.]
  bool isSignatureLegal(FunctionType funcType) {
    return llvm::all_of(funcType.getInputs(),
                        [this](Type type) { return isLegal(type); });
  }
};

} // end anonymous namespace.

//===----------------------------------------------------------------------===//
// Frontend to Krnl Dialect lowering pass
//===----------------------------------------------------------------------===//
@@ -1,4 +1,4 @@
-//===----- elementwise.inc - Elementwise Ops ------------------------------===//
+//===----- elementwise.cpp - Elementwise Ops ------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
template <>
struct ScalarOp<ONNXAddOp> {
  using FOp = AddFOp;
@@ -1,4 +1,4 @@
-//===----- gemm.inc - Lowering Gemm Op ------------------------------------===//
+//===----- gemm.cpp - Lowering Gemm Op ------------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
template <typename GemmOp>
struct ONNXGemmOpLowering : public ConversionPattern {
  ONNXGemmOpLowering(MLIRContext *ctx)
@@ -17,20 +21,22 @@ struct ONNXGemmOpLowering : public ConversionPattern {
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
-    auto has_bias = (operands.size() == 3);
+    bool hasBias = !op->getOperand(2).getType().isa<NoneType>();

    Value A, B, C;
    A = operands[0];
    B = operands[1];
-    if (has_bias)
+    if (hasBias)
      C = operands[2];

    auto memRefType = convertToMemRefType(*op->result_type_begin());

-    auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
-        llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
-    auto betaAttr = FloatAttr::get(memRefType.getElementType(),
-        llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
+    auto alphaAttr =
+        FloatAttr::get(memRefType.getElementType(),
+            llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
+    auto betaAttr =
+        FloatAttr::get(memRefType.getElementType(),
+            llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
    auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
    auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@@ -68,8 +74,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
    // Define loops.
    std::vector<Value> originalLoops;
    std::vector<Value> optimizedLoops;
-    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
-        optimizedLoops, numLoops);
+    Block *optimizationBlock =
+        defineLoops(rewriter, loc, originalLoops, optimizedLoops, numLoops);

    // We have two Krnl loops:
    // - Outer loop iterates over the output matrix dimensions, and
@@ -83,8 +89,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
      outerLoops.push_back(originalLoops[i]);
      optimizedOuterLoops.push_back(optimizedLoops[i]);
    }
-    KrnlIterateOperandPack outerPack(rewriter, outerLoops,
-        optimizedOuterLoops);
+    KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
    // Induction variables for the outer loops
    for (int i = 0; i < 2; ++i)
      addDimensionToPack(rewriter, loc, outerPack, alloc, i);
@@ -106,20 +111,19 @@ struct ONNXGemmOpLowering : public ConversionPattern {
    int64_t K_B_Idx = (isTransB) ? 1 : 0;
    reductionPack.pushConstantBound(0);
    if (ATy.getShape()[K_A_Idx] != -1)
      reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
+    else if (BTy.getShape()[K_B_Idx] != -1)
+      reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
    else
-      if (BTy.getShape()[K_B_Idx] != -1)
-        reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
-      else
-        reductionPack.pushOperandBound(
-            rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
+      reductionPack.pushOperandBound(
+          rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());

    // Get run-time dimension information for unknown dimensions used for
    // broadcasting.
    // GemmOp supports unidirectional broadcasting from C to A*B.
    // Hence, it must be enough to get broadcasting information for C only.
    std::map<int, Value> broadcastedDimInfo;
-    if (has_bias) {
+    if (hasBias) {
      auto shape = C.getType().cast<MemRefType>().getShape();
      for (int i = 0; i < shape.size(); ++i) {
        if (shape[i] < 0) {
@@ -162,7 +166,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
    // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
    auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
    auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
-    if (has_bias) {
+    if (hasBias) {
      auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
                                                broadcastedDimInfo);
      auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
@@ -210,8 +214,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
  }
};

-void populateLoweringONNXGemmOpPattern(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
+    MLIRContext *ctx) {
  patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
-  patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
}
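
The `hasBias` rewrite above leans on the convention that an omitted optional ONNX input is materialized as a value of `NoneType` rather than dropped from the operand list, which is why the separate `ONNXGemmNoBiasOp` registration disappears. A minimal sketch of that idiom, with an assumed helper name that is not part of this diff:

// Hypothetical helper illustrating the optional-operand test used above.
bool hasOptionalInput(Operation *op, unsigned idx) {
  return idx < op->getNumOperands() &&
         !op->getOperand(idx).getType().isa<NoneType>();
}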
@@ -1,4 +1,4 @@
-//===----- matmul.inc - Lowering Matmul Op --------------------------------===//
+//===----- matmul.cpp - Lowering Matmul Op --------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
struct ONNXMatMulOpLowering : public ConversionPattern {
  ONNXMatMulOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
-//===----- reduction.inc - Lowering Reduction Ops -------------------------===//
+//===----- reduction.cpp - Lowering Reduction Ops -------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
// Identity values
template <>
float getIdentityValue<float, ONNXReduceMaxOp>(){
@@ -1,4 +1,4 @@
-//===----- softmax.inc - Softmax Op ---------------------------------------===//
+//===----- softmax.cpp - Softmax Op ---------------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
struct ONNXSoftmaxOpLowering : public ConversionPattern {
  ONNXSoftmaxOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,4 @@
-//===----- conv.inc - Lowering Convolution Op -----------------------------===//
+//===----- conv.cpp - Lowering Convolution Op -----------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,13 +8,16 @@
//
//===----------------------------------------------------------------------===//

+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
struct ONNXConvNoBiasOpLowering : public ConversionPattern {
  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}

-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertToMemRefType(*op->result_type_begin());
@@ -25,12 +28,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
-          {operands[0]});
+      alloc = insertAllocAndDealloc(
+          memRefType, loc, rewriter, insertDealloc, {operands[0]});

    auto resultShape = memRefType.getShape();
-    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
-    auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();
+    auto &inputOperand = operands[0];
+    auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
+    auto &kernelOperand = operands[1];
+    auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();

    // R = ConvNoBias(D, K)
    //
@@ -91,123 +96,82 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
        loc, FloatAttr::get(memRefType.getElementType(), 0));
    Value subchannels;
    if (kernelShape[1] < 0) {
-      subchannels =
-          rewriter.create<DimOp>(loc, operands[1], 1).getResult();
+      subchannels = rewriter.create<DimOp>(loc, kernelOperand, 1).getResult();
    } else {
-      subchannels = rewriter.create<ConstantIndexOp>(
-          loc, kernelShape[1]);
+      subchannels = rewriter.create<ConstantIndexOp>(loc, kernelShape[1]);
    }

    // 1. Define outer loops and emit empty optimization block:
    int64_t nOuterLoops = (group > 1) ? 3 : 2;
-    std::vector<Value> outerLoops;
-    std::vector<Value> optimizedOuterLoops;
-    Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
-        optimizedOuterLoops, nOuterLoops);
-
-    // Prepare iteration arguments over outer loop nest.
-    KrnlIterateOperandPack pack(
-        rewriter, outerLoops, optimizedOuterLoops);
+    BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops);
+    outerLoops.createDefineAndOptimizeOp();
    // for n = 0 .. N:
-    pack.pushConstantBound(0);
-    if (inputShape[0] < 0)
-      pack.pushOperandBound(
-          rewriter.create<DimOp>(loc, operands[0], 0).getResult());
-    else
-      pack.pushConstantBound(inputShape[0]);
+    int nIndex = outerLoops.pushBounds(0, inputOperand, 0);
    // for g = 0 .. N:
-    if (group > 1) {
-      pack.pushConstantBound(0);
-      pack.pushConstantBound(group);
-    }
+    int gIndex = -1;
+    if (group > 1)
+      gIndex = outerLoops.pushBounds(0, group);
    // for m = 0 .. kernelsPerGroup:
-    pack.pushConstantBound(0);
-    pack.pushConstantBound(kernelsPerGroup);
-    // Outer loop iteration.
-    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
-    Block &outerIterationBlock = iterateOp.bodyRegion().front();
-    // Emit optimizations for outer loops:
-    rewriter.setInsertionPointToEnd(optimizationBlock);
-    rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
-    rewriter.setInsertionPointToStart(&outerIterationBlock);
+    int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
+    // Outer loop iteration
+    outerLoops.createIterateOp();
+    rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
    {
      // 2. Emit the body of the outer loop nest.

      // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
      // If group is not set then the value of the kernel ID is
      // identical to that of the loop over kernels.
-      Value kernel = outerIterationBlock.getArguments()[1];
+      Value kernel = outerLoops.getInductionVar(mIndex);
      if (group > 1) {
        // Middle loop is over groups and third loop is over the
        // kernel identifiers in the current group.
-        auto kernelsOffset = rewriter.create<MulIOp>(loc,
-            outerIterationBlock.getArguments()[1],
-            kernelsPerGroupValue);
-        kernel = rewriter.create<AddIOp>(loc, kernelsOffset,
-            outerIterationBlock.getArguments()[2]);
+        auto kernelsOffset = rewriter.create<MulIOp>(
+            loc, outerLoops.getInductionVar(gIndex), kernelsPerGroupValue);
+        kernel = rewriter.create<AddIOp>(
+            loc, kernelsOffset, outerLoops.getInductionVar(mIndex));
      }

      // 2.2 Define spatial loops
      int64_t nSpatialLoops = resultShape.size() - 2;
-      std::vector<Value> spatialLoops;
-      std::vector<Value> optimizedSpatialLoops;
-      Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
-          optimizedSpatialLoops, nSpatialLoops);
-
-      // 2.3 Prepare iteration arguments for spatial loop nest.
-      KrnlIterateOperandPack spatialPack(
-          rewriter, spatialLoops, optimizedSpatialLoops);
+      BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops);
+      spatialLoops.createDefineAndOptimizeOp();
      for (int i = 2; i < resultShape.size(); ++i)
-        addDimensionToPack(rewriter, loc, spatialPack, alloc, i);
+        spatialLoops.pushBounds(0, alloc, i);

      // 2.4 Emit loop nest over output spatial dimensions.
      // for rX = 0 .. RX
-      auto spatialIterateOp =
-          rewriter.create<KrnlIterateOp>(loc, spatialPack);
-      Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
-      // 2.5 Emit optimizations for outer loops:
-      rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
-      rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
-      rewriter.setInsertionPointToStart(&spatialIterationBlock);
+      spatialLoops.createIterateOp();
+      rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock());
      {
        // 3. Emit the body of the spatial loop nest.
        // 3.1 Emit: R[n][kernel][r1][r2] = 0;
        SmallVector<Value, 4> resultIndices;
        // n
-        resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
+        resultIndices.emplace_back(outerLoops.getInductionVar(nIndex));
        // kernel
        resultIndices.emplace_back(kernel);
        // rX
-        for (auto arg : spatialIterationBlock.getArguments())
+        for (auto arg : spatialLoops.getIterateBlock()->getArguments())
          resultIndices.emplace_back(arg);
        // Store initializer value into output location.
        rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);

        // 3.2 Define inner loops.
        int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
-        std::vector<Value> innerLoops;
-        std::vector<Value> optimizedInnerLoops;
-        Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
-            optimizedInnerLoops, nInnerLoops);
-
-        // 3.3 Prepare iteration arguments for inner loop nest.
-        KrnlIterateOperandPack innerPack(
-            rewriter, innerLoops, optimizedInnerLoops);
+        BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops);
+        innerLoops.createDefineAndOptimizeOp();
        // for c = 0 .. C/group
-        innerPack.pushConstantBound(0);
-        innerPack.pushConstantBound(kernelShape[1]);
+        int cIndex = innerLoops.pushBounds(0, kernelShape[1]);
        // for Kx = 0 .. KX
        for (int i = 2; i < kernelShape.size(); ++i)
-          addDimensionToPack(rewriter, loc, innerPack, operands[1], i);
+          innerLoops.pushBounds(0, kernelOperand, i);

        // 3.4 Emit inner loop nest.
-        auto innerIterateOp =
-            rewriter.create<KrnlIterateOp>(loc, innerPack);
-        Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
-        // 3.5 Emit optimizations for outer loops:
-        rewriter.setInsertionPointToEnd(optInnerLoopBlock);
-        rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
-        rewriter.setInsertionPointToStart(&innerIterationBlock);
+        innerLoops.createIterateOp();
+        rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
        {
          // 4. Emit inner loop body
          // R[n][kernel][r1][r2] =
@@ -217,13 +181,13 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
          // 4.1 Prepare indices for accessing the data tensor.
          SmallVector<Value, 4> dataIndices;
          // n
-          dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
+          dataIndices.emplace_back(outerLoops.getInductionVar(nIndex));
          // g * (C / group) + c
-          Value channelDepth = innerIterationBlock.getArguments()[0];
+          Value channelDepth = innerLoops.getInductionVar(cIndex);
          if (group > 1)
            channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
-                rewriter.create<MulIOp>(loc, subchannels,
-                    outerIterationBlock.getArguments()[1]));
+                rewriter.create<MulIOp>(
+                    loc, subchannels, outerLoops.getInductionVar(gIndex)));
          dataIndices.emplace_back(channelDepth);
          // sX * rX + kX
          auto stridesAttribute = convOp.stridesAttr();
@@ -233,15 +197,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
          for (auto stride : stridesAttribute.getValue())
            strides.emplace_back(stride.cast<IntegerAttr>().getInt());
          for (int i = 0; i < kernelShape.size() - 2; ++i) {
-            Value spatialIndex = spatialIterationBlock.getArguments()[i];
+            Value spatialIndex = spatialLoops.getInductionVar(i);
            // If strides are present then emit the correct access index.
            if (stridesAttribute && strides[i] > 1)
              spatialIndex = rewriter.create<MulIOp>(loc,
                  rewriter.create<ConstantIndexOp>(loc, strides[i]),
-                  spatialIterationBlock.getArguments()[i]);
-            dataIndices.emplace_back(
-                rewriter.create<AddIOp>(loc, spatialIndex,
-                    innerIterationBlock.getArguments()[i+1]));
+                  spatialLoops.getInductionVar(i));
+            dataIndices.emplace_back(rewriter.create<AddIOp>(
+                loc, spatialIndex, innerLoops.getInductionVar(i + 1)));
          }

          // 4.2 Prepare indices for accessing the kernel tensor.
@@ -249,17 +212,16 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
          // kernel
          kernelIndices.emplace_back(kernel);
          // c
-          kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
+          kernelIndices.emplace_back(innerLoops.getInductionVar(cIndex));
          // kX
          for (int i = 0; i < kernelShape.size() - 2; ++i)
-            kernelIndices.emplace_back(
-                innerIterationBlock.getArguments()[i+1]);
+            kernelIndices.emplace_back(innerLoops.getInductionVar(i + 1));

          // 4.3 Compute convolution.
          auto loadData =
-              rewriter.create<LoadOp>(loc, operands[0], dataIndices);
+              rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
          auto loadKernel =
-              rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
+              rewriter.create<LoadOp>(loc, kernelOperand, kernelIndices);
          auto loadPartialSum =
              rewriter.create<LoadOp>(loc, alloc, resultIndices);
          Value result = rewriter.create<AddFOp>(loc, loadPartialSum,
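
The conv rewrite above is the first heavy user of the new `BuildKrnlLoop` helper. Collecting the calls exercised by the right-hand side of this diff gives the following usage shape (a sketch of the call sequence, not the class definition):

// Usage shape of BuildKrnlLoop, as exercised in the new conv lowering:
BuildKrnlLoop loops(rewriter, loc, numLoops);  // wraps define/optimize/iterate
loops.createDefineAndOptimizeOp();             // emit loop definition + opt ops
int n = loops.pushBounds(0, memRefOperand, 0); // bound from operand's dim 0
int m = loops.pushBounds(0, upperBound);       // constant upper bound
loops.createIterateOp();                       // emit the krnl.iterate op
rewriter.setInsertionPointToStart(loops.getIterateBlock());
Value iv = loops.getInductionVar(n);           // IV for the n-th bound pair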
@@ -1,4 +1,4 @@
-//===----- normalization.inc - Lowering Normalization Ops -----------------===//
+//===----- normalization.cpp - Lowering Normalization Ops -----------------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -8,6 +8,10 @@
//
//===----------------------------------------------------------------------===//

+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
  ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
      : ConversionPattern(
@@ -0,0 +1,324 @@
//====-- onnx_to_krnl_common.cpp - ONNX dialects to Krnl lowering ---------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the KRNL dialect.
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

/// Check if all dimensions are known at compile time.
bool hasAllConstantDimensions(MemRefType type) {
  auto memRefShape = type.getShape();
  for (int i = 0; i < memRefShape.size(); ++i)
    if (memRefShape[i] < 0)
      return false;
  return true;
}

/// Get the corresponding MemRefType of a given TensorType/MemRefType.
MemRefType convertToMemRefType(Type type) {
  MemRefType memRefType;
  auto tensorType = type.dyn_cast<TensorType>();
  if (tensorType) {
    assert(tensorType.hasRank() && "expected only ranked shapes");
    memRefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  } else {
    memRefType = type.dyn_cast<MemRefType>();
  }
  return memRefType;
}
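
A quick illustration of the conversion (a hedged sketch; assumes a live MLIRContext named `context`, which is not part of this file):

// Sketch: a ranked tensor maps to a memref of the same shape/element type.
Builder b(&context);
Type t = RankedTensorType::get({10, -1}, b.getF32Type()); // tensor<10x?xf32>
MemRefType m = convertToMemRefType(t);                    // memref<10x?xf32>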

/// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc,
                            PatternRewriter &rewriter,
                            bool insertDealloc,
                            ArrayRef<Value> operands) {
  // Put together alloc operands for any dynamic dimensions of the memref.
  AllocOp alloc;
  if (!operands.empty()) {
    auto memRefShape = type.getShape();
    auto rank = memRefShape.size();

    std::map<int, Value> fromOperands;
    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
      int memRefDimIdx = rank - 1 - reversedIdx;
      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
        Value maxDim = nullptr;
        for (int i = 0; i < operands.size(); i++) {
          auto operandShape =
              operands[i].getType().cast<MemRefType>().getShape();
          int operandDimIdx = operandShape.size() - 1 - reversedIdx;

          if (operandDimIdx < 0)
            continue;

          // In case of operations with broadcasting, the dimension of the
          // alloc result is the maximum size along each dimension of the
          // operands.
          auto operandDim =
              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
          if (maxDim) {
            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
                                                        operandDim, maxDim);
            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
                                               maxDim);
          } else {
            maxDim = operandDim;
          }
        }
        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
      }
    }

    SmallVector<Value, 4> allocOperands;
    for (int i = 0; i < rank; ++i)
      if (memRefShape[i] < 0)
        allocOperands.push_back(fromOperands[i]);
    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
  } else {
    alloc = rewriter.create<AllocOp>(loc, type);
  }

  // Make sure to allocate at the beginning of the block if
  // all dimensions are known.
  auto *parentBlock = alloc.getOperation()->getBlock();
  if (hasAllConstantDimensions(type))
    alloc.getOperation()->moveBefore(&parentBlock->front());

  if (insertDealloc) {
    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
    dealloc.getOperation()->moveBefore(&parentBlock->back());
  }

  return alloc;
}
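
A typical call site, mirroring how the conv lowering earlier in this commit uses the helper (names such as `op`, `operands`, and `rewriter` come from a pattern's matchAndRewrite):

// Typical use inside a lowering pattern (sketch):
auto memRefType = convertToMemRefType(*op->result_type_begin());
bool insertDealloc = checkInsertDealloc(op);
Value alloc;
if (hasAllConstantDimensions(memRefType))
  alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
  // Passing operands lets the helper derive unknown result dims at run time.
  alloc = insertAllocAndDealloc(
      memRefType, loc, rewriter, insertDealloc, {operands[0]});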

// Determine if the current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
bool checkInsertDealloc(Operation *currentOp) {
  auto parentBlock = currentOp->getBlock();

  bool insertDealloc = true;
  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
    assert(currentOp->getNumResults() < 2 &&
           "No more than one result supported (for now).");
    // If there is at least one result to investigate.
    if (currentOp->getNumResults() > 0) {
      auto result = currentOp->getResult(0);
      for (const auto &operand : op.getOperands())
        if (operand == result)
          insertDealloc = false;
    }
  });

  return insertDealloc;
}
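
The effect is easiest to see on a tiny function (a worked example in pseudo-IR, written here only to illustrate the walk above):

// Worked example (pseudo-IR): the add's result is returned, so
// checkInsertDealloc(add) == false and its buffer must outlive the function;
// the mul's result is consumed internally, so its buffer gets a dealloc.
//   func @f(%a: memref<10xf32>) -> memref<10xf32> {
//     %m = "onnx.Mul"(%a, %a) ...   // dealloc inserted
//     %r = "onnx.Add"(%m, %a) ...   // no dealloc: %r is returned
//     return %r
//   }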

// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
  std::map<int64_t, int64_t> OutInDimMap;
  int64_t rank = inputTy.getRank();

  // Mark reduction axes.
  std::vector<bool> isReductionAxis;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
      isReductionAxis.push_back(true);
    else
      isReductionAxis.push_back(false);
  }

  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
    // If it is a reduction axis, there is no relationship among dimensions.
    if (isReductionAxis[inIndex]) {
      if (keepdims)
        outIndex++;
    } else {
      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
      outIndex++;
    }
  }

  return OutInDimMap;
}
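
A worked example makes the mapping concrete (values follow directly from the loop above):

// For a rank-3 input reduced over axes = {1}:
//   keepdims == false: output dim 0 -> input dim 0, output dim 1 -> input
//     dim 2, i.e. OutInDimMap = {0 -> 0, 1 -> 2};
//   keepdims == true:  the reduced axis keeps a size-1 slot with no mapping,
//     i.e. OutInDimMap = {0 -> 0, 2 -> 2}.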

// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter,
                        Location loc, KrnlIterateOperandPack &pack,
                        Value operand, int index) {
  auto shape = operand.getType().cast<MemRefType>().getShape();
  if (shape[index] < 0) {
    pack.pushConstantBound(0);
    pack.pushOperandBound(
        rewriter.create<DimOp>(loc, operand, index).getResult());
  } else {
    pack.pushConstantBound(0);
    pack.pushConstantBound(shape[index]);
  }
}
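
For instance, as a direct consequence of the branch above, given an operand of type memref<?x10xf32>:

// index == 0 (dynamic): pushes lower bound 0 and the SSA value
//   dim(operand, 0) as the upper operand bound;
// index == 1 (static):  pushes lower bound 0 and constant upper bound 10.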

// Function that defines the KRNL dialect loops and their respective
// optimized version.
KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
                   std::vector<Value> &loops,
                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
  // Define loops.
  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
  loops.reserve(numLoops);
  for (auto result : loopsOp.getResults())
    loops.push_back(result);

  // Define optimized version of the loops.
  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
  optimizedLoops.reserve(numLoops);
  for (auto result : optimizedLoopsOp.getResults())
    optimizedLoops.push_back(result);

  return optimizedLoopsOp;
}

// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
                   std::vector<Value> &loops,
                   std::vector<Value> &optimizedLoops,
                   int64_t numLoops) {
  KrnlOptimizeLoopsOp optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
  return &optimizedLoopsOp.region().front();
}

// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand(
    ConversionPatternRewriter &rewriter, Location loc, Value operand,
    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
    KrnlIterateOp &iterateOp) {
  // Operand shape.
  auto shape = operand.getType().cast<MemRefType>().getShape();

  // Number of loops.
  int64_t rank = shape.size();

  // Define loops and optimized loops.
  std::vector<Value> optimizedLoops;
  optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);

  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
  // Iterate over the loop nest.
  for (int i = 0; i < rank; ++i)
    addDimensionToPack(rewriter, loc, pack, operand, i);

  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
}

unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}
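
A few concrete values, as implied by the arithmetic above:

// memref<4x8xf32>         : f32 is 32 bits    -> 4 bytes per element
// memref<2xi1>            : i1 is 1 bit       -> divideCeil(1, 8) = 1 byte
// memref<4xvector<4xf32>> : 4 * 32 = 128 bits -> 16 bytes per element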

// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
                      MemRefType memRefType, ArrayRef<Value> operands) {
  auto memRefShape = memRefType.getShape();
  int64_t rank = memRefShape.size();
  // For unknown dimensions, we need to get dimension values at runtime in
  // order to do broadcasting.
  std::map<int, std::map<int, Value>> DimInfo;
  // For each result dimension, compute the number of sharing operands.
  // Sharing operands are operands sharing the same index (counting from the
  // rightmost to the leftmost) for a given dimension.
  std::map<int, int> sharedDimCount;
  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    int dimIdx = rank - 1 - reversedIdx;
    sharedDimCount[dimIdx] = 0;
    for (int i = 0; i < operands.size(); ++i) {
      auto shape = operands[i].getType().cast<MemRefType>().getShape();
      if (reversedIdx <= shape.size() - 1)
        sharedDimCount[dimIdx]++;
    }
  }
  // An unknown dimension can have a value of 1 or N (N > 1).
  // If its value is 1, it is a broadcasted dimension.
  // Otherwise, it is a non-broadcasted dimension.
  // We only care about unknown dimensions whose number of sharing operands is
  // more than one, since they are potentially broadcasted dimensions.
  for (int i = 0; i < operands.size(); ++i) {
    std::map<int, Value> broadcastedDims;
    auto shape = operands[i].getType().cast<MemRefType>().getShape();
    int size = shape.size();
    for (int j = 0; j < shape.size(); ++j) {
      if (shape[j] < 0 && sharedDimCount[rank - size + j] > 1) {
        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
        auto isBroadcasted =
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
      }
    }
    DimInfo.insert(std::make_pair(i, broadcastedDims));
  }
  return DimInfo;
}
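
A worked example, following directly from the logic above:

// For operands of type memref<?x10xf32> and memref<10xf32>, only the
// rightmost dimension is shared by both (sharedDimCount = {0: 1, 1: 2}).
// The leading `?` of the first operand is unknown but shared by a single
// operand, so no runtime check is emitted and DimInfo holds empty maps.
// For memref<?x10xf32> and memref<?x10xf32>, dimension 0 is unknown and
// shared twice, so a `dim == 1` predicate is recorded for each operand.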

// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
                          ArrayRef<Value> loopIVs, Value operand,
                          std::map<int, Value> broadcastedDims) {
  // `operand` must have a ranked type. This should have been checked by the
  // shape inference pass.
  auto operandShape = operand.getType().cast<MemRefType>().getShape();
  auto rank = operandShape.size();
  auto loopCount = loopIVs.size();

  std::vector<Value> newLoopIVs;
  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    auto dimIdx = rank - 1 - reversedIdx;
    auto loopIdx = loopCount - 1 - reversedIdx;
    if (operandShape[dimIdx] == 1) {
      // Broadcasted dimension.
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      newLoopIVs.insert(newLoopIVs.begin(), zero);
    } else if ((operandShape[dimIdx] == -1) &&
               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
      // Unknown dimension: it can have a value of 1 or N (N > 1).
      // If its value is 1, it is a broadcasted dimension;
      // otherwise it is a non-broadcasted dimension.
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
                                           loopIVs[loopIdx]);
      newLoopIVs.insert(newLoopIVs.begin(), idx);
    } else {
      // Non-broadcasted dimension.
      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
    }
  }
  return newLoopIVs;
}
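
A worked example derived from the three cases above: loading an operand C of type memref<1x?xf32> inside a 2-D loop nest with induction variables (i, j):

// dim 0 has static size 1            -> its IV becomes the constant 0;
// dim 1 is unknown (-1); if it has an entry in broadcastedDims, its IV
//   becomes select(dim == 1, 0, j), otherwise it stays j.
// The load is then emitted as C[0][select(...)] instead of C[i][j].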

@@ -0,0 +1,217 @@
//====-- onnx_to_krnl_common.hpp - ONNX dialects to Krnl lowering ---------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the KRNL dialect.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <map>

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "mlir/IR/PatternMatch.h"

#include "src/dialect/krnl/krnl_helper.hpp"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Common functions used when lowering the ONNX frontend dialect to KRNL.
//===----------------------------------------------------------------------===//

/// Check if all dimensions are known at compile time.
bool hasAllConstantDimensions(MemRefType type);

/// Get the corresponding MemRefType of a given TensorType/MemRefType.
MemRefType convertToMemRefType(Type type);

/// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc,
                            PatternRewriter &rewriter,
                            bool insertDealloc,
                            ArrayRef<Value> operands = {});

// Determine if the current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
bool checkInsertDealloc(Operation *currentOp);

// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);

// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter,
                        Location loc, KrnlIterateOperandPack &pack,
                        Value operand, int index);

// Function that defines the KRNL dialect loops and their respective
// optimized version.
KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
                   std::vector<Value> &loops,
                   std::vector<Value> &optimizedLoops, int64_t numLoops);

// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
                   std::vector<Value> &loops,
                   std::vector<Value> &optimizedLoops,
                   int64_t numLoops);

// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand(
    ConversionPatternRewriter &rewriter, Location loc, Value operand,
    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
    KrnlIterateOp &iterateOp);

unsigned getMemRefEltSizeInBytes(MemRefType memRefType);

// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
                      MemRefType memRefType, ArrayRef<Value> operands);

// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
                          ArrayRef<Value> loopIVs, Value operand,
                          std::map<int, Value> broadcastedDims);

//===----------------------------------------------------------------------===//
// This is to get a scalar operation of a given type for a specific operation.
//===----------------------------------------------------------------------===//
template <typename Op>
struct ScalarOp {
  using FOp = void;
  using IOp = void;
};

template <typename FOp>
using ScalarFOp = typename ScalarOp<FOp>::FOp;
template <typename IOp>
using ScalarIOp = typename ScalarOp<IOp>::IOp;

// Get the identity element of an operation.
// Return NULL if the function does not have an identity.
template <typename DataType, typename Op>
DataType getIdentityValue() {
  return NULL;
}

//===----------------------------------------------------------------------===//
// This is used in the innermost loop of a KrnlIterateOp to insert computation
// composed of one or many scalar ops.
// Use template specialization for each of the different ONNX operations.
//===----------------------------------------------------------------------===//
template <typename Op>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
                         ArrayRef<Value> operands,
                         ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Type element_type = operands.front().getType();
  if (element_type.isa<IntegerType>()) {
    return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
                                          mlir::None);
  } else if (element_type.isa<FloatType>()) {
    return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
                                          mlir::None);
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//

struct TensorTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;

  TensorTypeConverter() {
    addConversion(convertType);
  }

  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
    if (auto type = convertToMemRefType(t)) {
      results.push_back(type);
      return success();
    }

    results.push_back(t);
    return success();
  }

  /// Return true if the inputs and outputs of the given function type are
  /// legal. [Taken from MLIR and adapted to only check the legality of the
  /// inputs. Once unranked results can be handled gracefully this
  /// override needs to be removed in favour of the original MLIR one.]
  bool isSignatureLegal(FunctionType funcType) {
    return llvm::all_of(funcType.getInputs(),
                        [this](Type type) { return isLegal(type); });
  }
};
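
A sketch of how such a converter is typically wired into a conversion target (a hypothetical pass body under assumed names; not part of this header):

// Hypothetical pass body using the converter:
TensorTypeConverter tensorToMemRef;
ConversionTarget target(getContext());
// A FuncOp is legal only once its inputs no longer mention tensor types.
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
  return tensorToMemRef.isSignatureLegal(op.getType());
});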
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
// Functions to add lowering patterns for frontend operations.
//===----------------------------------------------------------------------===//

// `math` directory methods:

void populateLoweringONNXElementwiseOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
    MLIRContext *ctx);

void populateLoweringONNXMatMulOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXReductionOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXSoftmaxOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

// `nn` directory methods:

void populateLoweringONNXConvOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXNormalizationOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

// `tensor` directory methods:

void populateLoweringONNXUnsqueezeOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXTransposeOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXReshapeOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXIdentityOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);
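A lowering pass would then gather these into one pattern list before running the dialect conversion. A sketch under the assumption of the usual function-pass boilerplate (the exact applyPartialConversion signature differs between MLIR revisions):

OwningRewritePatternList patterns;
populateLoweringONNXElementwiseOpPattern(patterns, &getContext());
populateLoweringONNXGemmOpPattern(patterns, &getContext());
populateLoweringONNXMatMulOpPattern(patterns, &getContext());
populateLoweringONNXConvOpPattern(patterns, &getContext());
populateLoweringONNXReshapeOpPattern(patterns, &getContext());
// ... and so on for the remaining populate* functions ...
if (failed(applyPartialConversion(getFunction(), target, patterns)))
  signalPassFailure();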
@@ -1,4 +1,4 @@
-//===----- identity.inc - Lowering Identity Op ----------------------------===//
+//===----- identity.cpp - Lowering Identity Op ----------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//

+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXIdentityOpLowering : public ConversionPattern {
   ONNXIdentityOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}

@@ -1,4 +1,4 @@
-//===----- reshape.inc - Lowering Reshape Op ------------------------------===//
+//===----- reshape.cpp - Lowering Reshape Op ------------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//

+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXReshapeOpLowering : public ConversionPattern {
   ONNXReshapeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}

@@ -1,4 +1,4 @@
-//===----- transpose.inc - Lowering Transpose Op --------------------------===//
+//===----- transpose.cpp - Lowering Transpose Op --------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//

+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXTransposeOpLowering : public ConversionPattern {
   ONNXTransposeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}

@@ -1,4 +1,4 @@
-//===----- unsqueeze.inc - Lowering Unsqueeze Op --------------------------===//
+//===----- unsqueeze.cpp - Lowering Unsqueeze Op --------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//

+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXUnsqueezeOpLowering : public ConversionPattern {
   ONNXUnsqueezeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
@@ -1,4 +1,5 @@
 #include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/AffineExpr.h"

 #include "src/dialect/krnl/krnl_ops.hpp"
@@ -9,9 +10,8 @@ namespace onnf {

 using namespace mlir;

-ParseResult
-KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
-    Value &operand) {
+ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
+    const Type &operandType, Value &operand) {
   // If operand queue is empty, parse more operands and cache them.
   if (_operandRefQueue.empty()) {
     // Parse operand types:
@@ -19,7 +19,7 @@ KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
     _parser.parseOperandList(operand_refs);

     // Record operands:
-    for (auto& operand_ref : operand_refs)
+    for (auto &operand_ref : operand_refs)
       _operandRefQueue.emplace(operand_ref);
   }

@@ -48,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
   return success();
 }

-ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType,
-    Value &operand) {
+ParseResult KrnlDialectOperandParser::ParseOperand(
+    const Type &operandType, Value &operand) {
   if (ParseOptionalOperand(operandType, operand))
     return _parser.emitError(
         _parser.getCurrentLocation(), "Expecting an operand.");
@@ -65,8 +65,8 @@ ParseResult KrnlDialectOperandParser::ParseOperand(
   return success();
 }

-void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
-    unsigned numSymbols, OpAsmPrinter& p) {
+void printDimAndSymbolList(Operation::operand_iterator &begin, unsigned numDims,
+    unsigned numSymbols, OpAsmPrinter &p) {
   p << '(';
   p.printOperands(begin, begin + numDims);
   p << ')';
@@ -81,8 +81,8 @@ void printDimAndSymbolList(Operation::operand_iterator &begin, unsigned numDims,
 }

 void printBound(AffineMapAttr boundMap,
-    Operation::operand_iterator& boundOperandsBeg, const char* prefix,
-    OpAsmPrinter& p) {
+    Operation::operand_iterator &boundOperandsBeg, const char *prefix,
+    OpAsmPrinter &p) {
   AffineMap map = boundMap.getValue();

   // Check if this bound should be printed using custom assembly form.
@@ -120,9 +120,10 @@ void printBound(AffineMapAttr boundMap,
   printDimAndSymbolList(
       boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p);
 }
 } // namespace onnf

 namespace mlir {

 void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
   if (boundMaps.size() % 2 == 0)
     _operands.emplace_back(inputLoops[boundMaps.size() / 2]);
@@ -130,11 +131,143 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
   boundMaps.emplace_back(AffineMapAttr::get(map));
 }

-void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
+void KrnlIterateOperandPack::pushOperandBound(Value operand) {
   if (boundMaps.size() % 2 == 0)
     _operands.emplace_back(inputLoops[boundMaps.size() / 2]);
   AffineMap map = builder.getSymbolIdentityMap();
   boundMaps.emplace_back(AffineMapAttr::get(map));
   _operands.emplace_back(operand);
 }
+
+BuildKrnlLoop::BuildKrnlLoop(
+    ConversionPatternRewriter &rewriter, Location loc, int loopNum)
+    : rewriter(rewriter), loc(loc), originalLoopNum(loopNum), pack(NULL),
+      pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
+      createdIterateOp(false) {
+  if (originalLoopNum <= 0)
+    emitError(loc, "Expected positive number of original loops.");
+}
+
+BuildKrnlLoop::BuildKrnlLoop(
+    ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand)
+    : BuildKrnlLoop(rewriter, loc,
+          memRefOperand.getType().cast<MemRefType>().getShape().size()) {}
+
+BuildKrnlLoop::~BuildKrnlLoop() {
+  if (pack)
+    free(pack);
+}
+
+void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
+  // Insert define loop operation.
+  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
+  originalLoops.reserve(originalLoopNum);
+  for (auto result : loopsOp.getResults())
+    originalLoops.push_back(result);
+  createdDefineOp = true;
+
+  // Insert optimize loop operation.
+  auto optimizedLoopsOp =
+      rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
+  optLoops.reserve(originalLoopNum);
+
+  // Emit empty optimizations if flag is set.
+  if (withEmptyOptimization) {
+    for (auto result : optimizedLoopsOp.getResults())
+      optLoops.push_back(result);
+    optBlock = &optimizedLoopsOp.region().front();
+    auto ip = rewriter.saveInsertionPoint();
+    rewriter.setInsertionPointToEnd(optBlock);
+    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+    rewriter.restoreInsertionPoint(ip);
+  }
+  createdOptimizeOp = true;
+
+  // prepare data structure to push bounds
+  pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
+}
+
+int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
+  pack->pushConstantBound(lowerBound);
+  pack->pushConstantBound(upperBound);
+  return pushCount++;
+}
+
+int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
+  pack->pushConstantBound(lowerBound);
+  pack->pushOperandBound(upperBound);
+  return pushCount++;
+}
+
+int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
+    int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
+  pack->pushConstantBound(lowerBound);
+
+  // Process upperBound as a dimension of the MemRef. Non-constant dimensions
+  // are supported.
+  auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
+  if (shape[upperBoundMemRefIndex] < 0) {
+    if (upperBoundMustBeConstant)
+      emitError(loc, "Bound expected to be constant.");
+    pack->pushOperandBound(
+        rewriter
+            .create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
+            .getResult());
+  } else
+    pack->pushConstantBound(shape[upperBoundMemRefIndex]);
+
+  return pushCount++;
+}
+
+int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) {
+  pack->pushOperandBound(lowerBound);
+  pack->pushOperandBound(upperBound);
+  return pushCount++;
+}
+
+void BuildKrnlLoop::createIterateOp() {
+  // Loop definition operation is mandatory.
+  if (!createdDefineOp)
+    emitError(loc, "Must create define op before iterate op.");
+
+  // Loop optimization operation is mandatory (for now).
+  if (!createdOptimizeOp)
+    emitError(loc, "Must create optimize op before iterate op.");
+
+  // Check if all bounds have been defined.
+  if (pushCount != originalLoopNum)
+    emitError(loc, "Must push bounds for all original loops.");
+
+  // Emit iteration operation.
+  auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
+  iterBlock = &iterateOp.bodyRegion().front();
+  createdIterateOp = true;
+}
+
+void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
+    Value memRefOperand, bool withEmptyOptimization) {
+  // Rank of the MemRef operand. We will emit a loop for each dimension.
+  int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
+  if (originalLoopNum != loopNum)
+    emitError(loc, "Mismatch in loop numbers from constructor and define.");
+
+  // Emit the definition and the optimization operations for the loop nest.
+  createDefineAndOptimizeOp(withEmptyOptimization);
+
+  // Push a lower-upper bound pair for each dimension of the MemRef operand.
+  // The lower bound in this case is always zero.
+  for (int i = 0; i < originalLoopNum; ++i)
+    pushBounds(0, memRefOperand, i);
+
+  // Emit the iteration operation over the current loop nest.
+  createIterateOp();
+}
+
+BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
+  // Check if loop iteration variable is within bounds.
+  if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
+    emitError(loc, "Original loop index is out of bounds.");
+  return iterBlock->getArguments()[originalLoopIndex];
+}
+
 } // namespace mlir
@@ -8,39 +8,38 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/DialectConversion.h"

 namespace onnf {

 class KrnlDialectOperandParser {
 public:
-  explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser)
+  explicit KrnlDialectOperandParser(mlir::OpAsmParser &parser)
       : _parser(parser), _builder(parser.getBuilder()){};

   // Parse an optional operand.
-  mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
-      mlir::Value &operand);
+  mlir::ParseResult ParseOptionalOperand(
+      const mlir::Type &operandType, mlir::Value &operand);

   // Parse an optional operand and push it to an operand list.
-  mlir::ParseResult
-  ParseOptionalOperand(const mlir::Type &operandType,
-      llvm::SmallVectorImpl<mlir::Value> &operandList);
+  mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
+      llvm::SmallVectorImpl<mlir::Value> &operandList);

   // Parse a required operand.
-  mlir::ParseResult ParseOperand(const mlir::Type &operandType,
-      mlir::Value &operand);
+  mlir::ParseResult ParseOperand(
+      const mlir::Type &operandType, mlir::Value &operand);

   // Parse a required operand and push it to an operand list.
-  mlir::ParseResult
-  ParseOperand(const mlir::Type &operandType,
-      llvm::SmallVectorImpl<mlir::Value> &operandList);
+  mlir::ParseResult ParseOperand(const mlir::Type &operandType,
+      llvm::SmallVectorImpl<mlir::Value> &operandList);

   // Do we have more operands to parse?
   bool hasOperandLeft() { return !_operandRefQueue.empty(); }

 private:
-  mlir::OpAsmParser& _parser;
+  mlir::OpAsmParser &_parser;

-  mlir::Builder& _builder;
+  mlir::Builder &_builder;

   // A queue storing the parsed SSA id references.
   std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue;
@@ -50,24 +49,24 @@ class KrnlDialectOperandParser {
 // https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
 // Main difference is that it advances the iterator `begin` as it consumes
 // dimension and symbol operands.
-void printDimAndSymbolList(mlir::Operation::operand_iterator& begin,
-    unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p);
+void printDimAndSymbolList(mlir::Operation::operand_iterator &begin,
+    unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter &p);

 // Adapted from:
 // https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
 // Main difference is that it advances the iterator `boundOperandsBeg` as it
 // prints bound.
 void printBound(mlir::AffineMapAttr boundMap,
-    mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix,
-    mlir::OpAsmPrinter& p);
+    mlir::Operation::operand_iterator &boundOperandsBeg, const char *prefix,
+    mlir::OpAsmPrinter &p);
 } // namespace onnf

 namespace mlir {

 struct KrnlIterateOperandPack {
   KrnlIterateOperandPack(mlir::Builder &builder,
       llvm::ArrayRef<mlir::Value> inputLoops,
       llvm::ArrayRef<mlir::Value> optimizedLoops)
       : builder(builder), inputLoops(inputLoops),
         optimizedLoops(optimizedLoops) {
     _operands.insert(
@@ -88,7 +87,7 @@ struct KrnlIterateOperandPack {

   size_t getNumInputLoops() const { return inputLoops.size(); }

 private:
   int _boundIdx = 0;

   llvm::SmallVector<mlir::Value, 8> _operands;
@@ -97,7 +96,124 @@ struct KrnlIterateOperandPack {

   llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;

-  mlir::Builder& builder;
+  mlir::Builder &builder;
 };

+// Helper class to write kernel loops. This class will let us build a single
+// define/optimize/iterate operation combo. We can then insert optimizations in
+// the body of the optimization operation, and operations in the body of the
+// iterate operation.
+//
+// The sequence is as follows:
+//
+// 1) Create an object giving the rewriter, location, and number of loops in
+//    the original (non-optimized) loop nest.
+//
+// 2) Create define & optimize ops (currently paired). Optimizations can then
+//    be added to the inner block of the optimize operation. Make sure to set
+//    the insertion point to that block for optimizations to go in the right
+//    place.
+//
+// 3) Push the bounds for each of the original loops. Bounds are pushed in
+//    pairs (lower & upper bounds). There are a few methods to do it depending
+//    on the type of the bounds. When pushing bounds, the method returns a
+//    number that represents the index associated with that iteration
+//    (induction variable and bounds). That index can be used later to extract
+//    the induction variable for reference in computations and/or index
+//    calculations of MemRefs.
+//
+// 4) Once all the bounds are pushed, create the iterate operation. Once this
+//    is done, we can add operations within the iterate blocks by setting the
+//    insertion point to it. Values of the induction variables can be
+//    retrieved using the proper index (determined when pushing the bounds).
+
+class BuildKrnlLoop {
+public:
+  // Create kernel loop builder for a loop nest of depth loopNum.
+  BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
+
+  // Create kernel loop builder for a loop nest of depth equal to the
+  // dimensionality of the operand. An operand of MemRef type is required.
+  BuildKrnlLoop(
+      ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
+  ~BuildKrnlLoop();
+
+  // Create define and optimize loop with loopNum original loops. If
+  // withEmptyOptimization is true, the optimization is simply the identity
+  // function (no optimizations).
+  void createDefineAndOptimizeOp(bool withEmptyOptimization = true);
+
+  // Push bounds (lower and upper) for each of the loops (order matters).
+  // The function returns the order number associated with the loop iteration.
+  // This index is used by the getInductionVar call. Non-constant operands
+  // must be of MemRef type.
+  int pushBounds(int64_t lowerBound, int64_t upperBound);
+  int pushBounds(int64_t lowerBound, Value upperBound);
+  int pushBounds(Value lowerBound, Value upperBound);
+  int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
+      int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);
+
+  // Create the KrnlIterateOp associated with this loop nest. The loops
+  // iteration will be created if the definition and the optimization
+  // operations associated with this loop nest have been emitted already.
+  void createIterateOp();
+
+  // Create the loop nest definition, optimization and iteration operations
+  // for a given operand of MemRef type. The loop nest has a depth equal to the
+  // rank of the MemRef operand. The lower bound of each loop is zero. The
+  // upper bound of each loop is given by the corresponding dimension of the
+  // MemRef operand.
+  void createDefineOptimizeAndIterateOp(
+      Value memRefOperand, bool withEmptyOptimization = true);
+
+  // Get the (original loop) induction variable associated with the given
+  // index. Use the index returned when pushing the bounds.
+  BlockArgument &getInductionVar(int originalLoopIndex);
+
+  // Get a reference to the code region of the optimization operation.
+  // This allows us to set the insertion point to the inner block of the
+  // loop nest optimization operation.
+  Block *getOptimizationBlock() { return optBlock; }
+
+  // Get a reference to the code region of the iteration operation.
+  // This allows us to set the insertion point to the inner block of the
+  // loop nest iteration operation.
+  Block *getIterateBlock() { return iterBlock; }
+
+  // Get original loop nest.
+  std::vector<Value> &getOriginalLoops() { return originalLoops; }
+
+  // Get optimized loop nest.
+  std::vector<Value> &getOptimizedLoops() { return optLoops; }
+
+private:
+  // Required for emitting operations.
+  ConversionPatternRewriter &rewriter;
+  Location loc;
+  int originalLoopNum;
+
+  // List of original, un-optimized loops.
+  std::vector<Value> originalLoops;
+
+  // List of optimized loops.
+  std::vector<Value> optLoops;
+
+  // List of lower-upper bound pairs needed by the KrnlIterateOp.
+  KrnlIterateOperandPack *pack;
+
+  // Number of lower-upper bound pairs pushed.
+  int pushCount;
+
+  // Flags that keep track of emitted operations.
+  bool createdDefineOp;
+  bool createdOptimizeOp;
+  bool createdIterateOp;
+
+  // Saved insertion point in the code region of the KrnlOptimizeLoopsOp.
+  Block *optBlock;
+
+  // Saved insertion point in the code region of the KrnlIterateOp.
+  Block *iterBlock;
+};
+
 } // namespace mlir
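To make the four-step sequence in the class comment concrete, here is a hedged usage sketch as it might appear inside a lowering pattern (memRefOperand, rewriter, and loc are assumed to be in scope; names are illustrative):

// Step 1: builder for a 2-deep loop nest.
BuildKrnlLoop loops(rewriter, loc, /*loopNum=*/2);
// Step 2: emit the paired define/optimize operations.
loops.createDefineAndOptimizeOp();
// Step 3: one lower/upper bound pair per loop; keep the returned indices.
int i0 = loops.pushBounds(0, memRefOperand, /*upperBoundMemRefIndex=*/0);
int i1 = loops.pushBounds(0, memRefOperand, /*upperBoundMemRefIndex=*/1);
// Step 4: emit krnl.iterate, then build the body at its insertion point.
loops.createIterateOp();
rewriter.setInsertionPointToStart(loops.getIterateBlock());
// Induction variables are recovered with the indices returned by pushBounds.
Value iv0 = loops.getInductionVar(i0);
Value iv1 = loops.getInductionVar(i1);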
@@ -90,25 +90,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
 // or outputs. This decision affects only ONNX operations with optional
 // arguments not ONNX operations with variadic operands.

-def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
-    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
-  let summary = "ONNX general matrix multiply operation without bias.";
-  let description = [{
-
-    The "onnx.Gemm" generic matrix multiplication without bias.
-
-  }];
-
-  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
-      AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
-      DefaultValuedAttr<F32Attr, "1.0">:$alpha,
-      DefaultValuedAttr<F32Attr, "1.0">:$beta,
-      DefaultValuedAttr<I64Attr, "0">:$transA,
-      DefaultValuedAttr<I64Attr, "0">:$transB);
-
-  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
-}
-
 def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let hasCanonicalizer = 1;
@@ -24,12 +24,29 @@
 using namespace mlir;
 using namespace mlir::OpTrait::util;

+//===----------------------------------------------------------------------===//
+// ONNX Helper functions
+//===----------------------------------------------------------------------===//
+
+static size_t ArrayAttrSize(ArrayAttr a) { return a.size(); }
+
+static size_t ArrayAttrSize(Optional<ArrayAttr> a) {
+  return a.getValue().size();
+}
+
+static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
+  return (a.getValue()[i]).cast<IntegerAttr>().getInt();
+}
+
+static int64_t ArrayAttrIntVal(Optional<ArrayAttr> a, int i) {
+  return (a.getValue().getValue()[i]).cast<IntegerAttr>().getInt();
+}
+
 //===----------------------------------------------------------------------===//
 // Get reduction type
 //===----------------------------------------------------------------------===//
-RankedTensorType getReductionOutputType(RankedTensorType operandTy,
-    Optional<ArrayAttr> axesAttrs,
-    APInt keepdims) {
+RankedTensorType getReductionOutputType(
+    RankedTensorType operandTy, Optional<ArrayAttr> axesAttrs, APInt keepdims) {
   int64_t rank = operandTy.getRank();

   SmallVector<int64_t, 4> axes;
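These overloads let shape-inference code index ONNX attributes uniformly whether they arrive as a plain ArrayAttr or as an Optional<ArrayAttr>. A small illustrative use (kernel_shape here is a stand-in for any optional integer-array attribute accessor):

// Sketch only: collect the integer entries of an optional array attribute.
SmallVector<int64_t, 4> kernelDims;
if (auto kernelShape = kernel_shape()) { // hypothetical Optional<ArrayAttr>
  for (size_t i = 0; i < ArrayAttrSize(kernelShape); ++i)
    kernelDims.push_back(ArrayAttrIntVal(kernelShape, i));
}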
@@ -87,19 +104,18 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
 }

 void ONNXEntryPointOp::build(mlir::Builder *builder,
-    mlir::OperationState &state, mlir::FuncOp function,
-    int numInputs, int numOutputs) {
+    mlir::OperationState &state, mlir::FuncOp function, int numInputs,
+    int numOutputs) {
   state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
       builder->getSymbolRefAttr(function));
   state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
       builder->getI32IntegerAttr(numInputs));
   state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
       builder->getI32IntegerAttr(numOutputs));
 }

 ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
-    mlir::FuncOp &func, int numInputs,
-    int numOutputs) {
+    mlir::FuncOp &func, int numInputs, int numOutputs) {
   mlir::OperationState state(location, "onnx.EntryPoint");
   Builder builder(location->getContext());
   mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
@@ -120,25 +136,19 @@ void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
 // Tanh
 /// Infer the output shape of the ONNXTanhOp. This method is required by the
 /// shape inference interface.
-void ONNXTanhOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Sinh
 /// Infer the output shape of the ONNXSinhOp. This method is required by the
 /// shape inference interface.
-void ONNXSinhOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Cosh
 /// Infer the output shape of the ONNXCoshOp. This method is required by the
 /// shape inference interface.
-void ONNXCoshOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Cos
@@ -178,9 +188,7 @@ void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
 // Relu
 /// Infer the output shape of the ONNXReluOp. This method is required by the
 /// shape inference interface.
-void ONNXReluOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // LeakyRelu
@@ -194,9 +202,7 @@ void ONNXLeakyReluOp::inferShapes() {
 // Selu
 /// Infer the output shape of the ONNXSeluOp. This method is required by
 /// the shape inference interface.
-void ONNXSeluOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Reciprocal
@@ -234,17 +240,13 @@ void ONNXSoftsignOp::inferShapes() {
 // Sqrt
 /// Infer the output shape of the ONNXSqrtOp. This method is required by
 /// the shape inference interface.
-void ONNXSqrtOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Sign
 /// Infer the output shape of the ONNXSignOp. This method is required by
 /// the shape inference interface.
-void ONNXSignOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Add
@@ -404,12 +406,12 @@ void ONNXIdentityOp::inferShapes() {

 void ONNXMatMulOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+  if (!A().getType().isa<RankedTensorType>() ||
+      !B().getType().isa<RankedTensorType>())
     return;

-  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  auto lhsTy = A().getType().cast<RankedTensorType>();
+  auto rhsTy = B().getType().cast<RankedTensorType>();

   SmallVector<int64_t, 2> dims;
   auto lhsShape = lhsTy.getShape();
@@ -417,15 +419,14 @@ void ONNXMatMulOp::inferShapes() {

   if (lhsShape.size() < 1 && rhsShape.size() < 1) {
     // Multiplication by scalars is not allowed.
-    emitError("Multiplication by scalar arguments not allowed.");
+    emitError("Multiplication by scalar arguments not allowed");
   } else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
     // Special case when both arrays are 1-dimensional and according to
     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
     // need to be removed after the multiplication but cannot be removed if all
     // sizes are 1.
-    if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
-        lhsShape[0] != rhsShape[0])
-      emitError("Attempt to multiply incompatible matrices.");
+    if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
+      emitError("Attempt to multiply incompatible matrices");
     dims.emplace_back(1);
   } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
     // If the first argument is 1-D, it is promoted to a matrix by prepending a
@@ -440,7 +441,7 @@ void ONNXMatMulOp::inferShapes() {
     unsigned rhsRank = rhsShape.size();
     if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
         lhsShape[0] != rhsShape[rhsRank - 2])
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
       dims.emplace_back(rhsShape[i]);
@@ -458,7 +459,7 @@ void ONNXMatMulOp::inferShapes() {
     unsigned lhsRank = lhsShape.size();
     if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
         lhsShape[lhsRank - 1] != rhsShape[0])
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
       dims.emplace_back(lhsShape[i]);
@@ -472,7 +473,7 @@ void ONNXMatMulOp::inferShapes() {
     unsigned lhsRank = lhsShape.size();
     if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
         lhsShape[lhsRank - 1] != rhsShape[0])
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
       dims.emplace_back(lhsShape[i]);
@@ -486,7 +487,7 @@ void ONNXMatMulOp::inferShapes() {
     unsigned rhsRank = rhsShape.size();
     if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
         lhsShape[1] != rhsShape[rhsRank - 2])
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
       dims.emplace_back(rhsShape[i]);
@@ -502,7 +503,7 @@ void ONNXMatMulOp::inferShapes() {
     unsigned rhsRank = rhsShape.size();
     if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
         lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     // Check and perform broadcasting for the shapes.
     SmallVector<int64_t, 2> lhsBcastShape;
@@ -512,7 +513,7 @@ void ONNXMatMulOp::inferShapes() {
     for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
       rhsBcastShape.emplace_back(rhsShape[i]);
     if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
-      emitError("Broadcasted dimensions are incompatible.");
+      emitError("Broadcasted dimensions are incompatible");

     dims.emplace_back(lhsShape[lhsRank - 2]);
     dims.emplace_back(rhsShape[rhsRank - 1]);
@@ -527,7 +528,7 @@ void ONNXMatMulOp::inferShapes() {

     // Check legality of matrix multiplication.
     if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     if (rhsShape.size() > 1)
       dims.emplace_back(rhsShape[1]);
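The branch structure above mirrors numpy.matmul's shape rules. A few shapes the code is meant to infer (dynamic sizes written as -1; examples are illustrative):

// (3)          x (3)           -> (1)            1-D x 1-D; this implementation keeps a helper size of 1
// (3)          x (5, 3, 2)     -> (5, 2)         1-D lhs promoted to (1, 3), helper dim dropped
// (2, 3, 4)    x (4)           -> (2, 3)         1-D rhs promoted to (4, 1), helper dim dropped
// (7, 1, 2, 3) x (1, 5, 3, 4)  -> (7, 5, 2, 4)   batch dims broadcast; inner dims must agree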
@@ -541,14 +542,14 @@ void ONNXMatMulOp::inferShapes() {
 // Gemm

 void ONNXGemmOp::inferShapes() {
+  bool hasBias = !C().getType().isa<NoneType>();
   // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>() ||
-      !getOperand(2).getType().isa<RankedTensorType>())
+  if (!A().getType().isa<RankedTensorType>() ||
+      !B().getType().isa<RankedTensorType>() ||
+      (hasBias && !C().getType().isa<RankedTensorType>()))
     return;
-  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
-  auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
+  auto lhsTy = A().getType().cast<RankedTensorType>();
+  auto rhsTy = B().getType().cast<RankedTensorType>();

   int64_t M, N, K_A, K_B;
   M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
@@ -557,44 +558,21 @@ void ONNXGemmOp::inferShapes() {
   K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];

   if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
-    emitError("Tensor shapes mismatched.");
+    emitError("Tensor shapes mismatched");
   }

-  // Check whether bias is unidirectional broadcasting or not.
-  auto shape = biasTy.getShape();
-  int rank = shape.size();
-  if ((rank > 2) ||
-      (rank >= 1 && shape[rank - 1] != -1 && N != -1 && N != shape[rank - 1] &&
-       shape[rank - 1] != 1) ||
-      (rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] &&
-       shape[rank - 2] != 1)) {
-    emitError("Bias shape mismatched.");
-  }
-
-  SmallVector<int64_t, 2> dims;
-  dims.emplace_back(M);
-  dims.emplace_back(N);
-  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
-}
-
-// GemmNoBias
-
-void ONNXGemmNoBiasOp::inferShapes() {
-  // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
-    return;
-  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
-
-  int64_t M, N, K_A, K_B;
-  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
-  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
-  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
-  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
-
-  if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
-    emitError("Tensor shapes mismatched.");
-  }
+  if (hasBias) {
+    // Check whether bias is unidirectional broadcasting or not.
+    auto biasTy = C().getType().cast<RankedTensorType>();
+    auto shape = biasTy.getShape();
+    int rank = shape.size();
+    if ((rank > 2) ||
+        (rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
+         N != shape[rank - 1] && shape[rank - 1] != 1) ||
+        (rank == 2 && shape[rank - 2] != -1 && M != -1 &&
+         M != shape[rank - 2] && shape[rank - 2] != 1)) {
+      emitError("Bias shape mismatched");
+    }
   }

   SmallVector<int64_t, 2> dims;
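For reference, Gemm computes Y = alpha * A' * B' + beta * C (the primes denoting the optional transposes), so the checks above reduce to simple bookkeeping:

// A (4, 2) with transA = 1 gives M = 2, K_A = 4.
// B (4, 3) with transB = 0 gives K_B = 4, N = 3; K_A == K_B holds.
// Y is (M, N) = (2, 3); when C is present it must broadcast
// unidirectionally to (M, N): C of shape (3), (1, 3), or (2, 3) all pass.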
@@ -606,50 +584,50 @@ void ONNXGemmNoBiasOp::inferShapes() {
 /// BatchNormalizationTestMode
 void ONNXBatchNormalizationTestModeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>() ||
-      !getOperand(2).getType().isa<RankedTensorType>() ||
-      !getOperand(3).getType().isa<RankedTensorType>() ||
-      !getOperand(4).getType().isa<RankedTensorType>())
+  if (!X().getType().isa<RankedTensorType>() ||
+      !scale().getType().isa<RankedTensorType>() ||
+      !B().getType().isa<RankedTensorType>() ||
+      !mean().getType().isa<RankedTensorType>() ||
+      !var().getType().isa<RankedTensorType>())
     return;

-  auto input = getOperand(0).getType().cast<RankedTensorType>();
-  auto scale = getOperand(1).getType().cast<RankedTensorType>();
-  auto bias = getOperand(2).getType().cast<RankedTensorType>();
-  auto mean = getOperand(3).getType().cast<RankedTensorType>();
-  auto variance = getOperand(4).getType().cast<RankedTensorType>();
+  auto inputTensorTy = X().getType().cast<RankedTensorType>();
+  auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
+  auto biasTensorTy = B().getType().cast<RankedTensorType>();
+  auto meanTensorTy = mean().getType().cast<RankedTensorType>();
+  auto varianceTensorTy = var().getType().cast<RankedTensorType>();

   // Check whether the shapes of scale, bias, mean and variance are valid.
   // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
   // In case of N, C is assumed to be 1.
   // Shapes of scale, bias, mean and variance must be C.
   int64_t c = -1;
-  if (input.getShape().size() == 1) {
+  if (inputTensorTy.getShape().size() == 1) {
     c = 1;
-  } else if (input.getShape().size() > 2) {
-    c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1;
+  } else if (inputTensorTy.getShape().size() > 2) {
+    c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
   } else {
-    emitError("Wrong rank for the input.");
+    emitError("Wrong rank for the input");
   }

   if (c != -1) {
-    auto s = scale.getShape();
-    auto b = bias.getShape();
-    auto m = mean.getShape();
-    auto v = variance.getShape();
+    auto s = scaleTensorTy.getShape();
+    auto b = biasTensorTy.getShape();
+    auto m = meanTensorTy.getShape();
+    auto v = varianceTensorTy.getShape();

     if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
-      emitError("Wrong rank for the scale.");
+      emitError("Wrong rank for the scale");
     if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
-      emitError("Wrong rank for the bias.");
+      emitError("Wrong rank for the bias");
     if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
-      emitError("Wrong rank for the mean.");
+      emitError("Wrong rank for the mean");
     if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
-      emitError("Wrong rank for the variance.");
+      emitError("Wrong rank for the variance");
   }

   // The output tensor of the same shape as the input.
-  getResult().setType(getOperand(0).getType());
+  getResult().setType(X().getType());
 }

 // TODO:
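A concrete instance of the test-mode contract enforced above, with illustrative sizes:

// X     (N, C, H, W) = (1, 64, 28, 28)
// scale (C)          = (64)
// B     (C)          = (64)
// mean  (C)          = (64)
// var   (C)          = (64)
// output: (1, 64, 28, 28), the same shape as X.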
@@ -662,21 +640,21 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {

 void ONNXReshapeOp::inferShapes() {
   // Cannot infer shape if no shape tensor is specified.
-  if (!getOperand(1).getType().isa<RankedTensorType>())
-    emitError("Shape tensor not ranked.");
+  if (!shape().getType().isa<RankedTensorType>())
+    emitError("Shape tensor not ranked");

-  auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>();
+  auto inputTensorTy = data().getType().cast<RankedTensorType>();
+  auto shapeTensorTy = shape().getType().cast<RankedTensorType>();

   // Only rank 1 shape tensors are supported.
   if (shapeTensorTy.getShape().size() != 1)
-    emitError("Shape tensor must have rank one.");
+    emitError("Shape tensor must have rank one");

   int64_t outputRank = shapeTensorTy.getShape()[0];

   // Shape tensor must have constant shape.
   if (outputRank < 0)
-    emitError("Shape tensor must have constant shape.");
+    emitError("Shape tensor must have constant shape");

   SmallVector<int64_t, 2> dims;
   for (int i = 0; i < outputRank; ++i)
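As far as the visible logic goes, the inferred rank comes from the static extent of the shape operand: for example, data of type tensor<2x3x4xf32> with a shape operand of type tensor<2xi64> yields a rank-2 result, each output dimension taken from the shape tensor when it is a compile-time constant and left dynamic otherwise.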
@@ -692,12 +670,12 @@ void ONNXReshapeOp::inferShapes() {
 
 void ONNXTransposeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!getOperand().getType().isa<RankedTensorType>())
+  if (!data().getType().isa<RankedTensorType>())
     return;
 
   // Naive transposition which handles the default case of
   // reversing the shape of the tensor (similar to numpy.transpose).
-  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+  auto arrayTy = data().getType().cast<RankedTensorType>();
   SmallVector<int64_t, 2> dims;
   auto permutation = ONNXTransposeOp::permAttr();
   if (permutation) {
@@ -713,14 +691,13 @@ void ONNXTransposeOp::inferShapes() {
   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
 }
 
 
 //===----------------------------------------------------------------------===//
 
 // ReduceMax
 
 void ONNXReduceMaxOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked.");
+    emitError("Shape tensor not ranked");
     return;
   }
 
@@ -734,7 +711,7 @@ void ONNXReduceMaxOp::inferShapes() {
 
 void ONNXReduceMinOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked.");
+    emitError("Shape tensor not ranked");
     return;
   }
 
@@ -748,7 +725,7 @@ void ONNXReduceMinOp::inferShapes() {
 
 void ONNXReduceProdOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked.");
+    emitError("Shape tensor not ranked");
     return;
   }
 
@@ -762,7 +739,7 @@ void ONNXReduceProdOp::inferShapes() {
 
 void ONNXReduceSumOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked.");
+    emitError("Shape tensor not ranked");
     return;
   }
 
@@ -781,30 +758,31 @@ void ONNXConvNoBiasOp::inferShapes() {
   // W: (M x C/group x k1 x k2 x ... x kn)
 
   // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+  if (!X().getType().isa<RankedTensorType>() ||
+      !W().getType().isa<RankedTensorType>())
     return;
 
-  auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
+  auto dataTy = X().getType().cast<RankedTensorType>();
+  auto weightTy = W().getType().cast<RankedTensorType>();
   auto dataShape = dataTy.getShape();
   auto weightShape = weightTy.getShape();
 
   // Lowest supported convolution is a one dimensional convolution.
   if (dataShape.size() < 3)
-    emitError("Data input shape must be at least (NxCxD1).");
+    emitError("Data input shape must be at least (NxCxD1)");
 
   // Check that shape of weight and data have same length.
   if (dataShape.size() != weightShape.size())
-    emitError("Weight size not compatible with data size.");
+    emitError("Weight size not compatible with data size");
 
   // Required attribute auto_pad defaults to NOTSET.
   auto autoPad = auto_pad();
   // Group is a required attribute and should have default value of 1.
-  int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
+  int64_t group =
+      ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
   // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
   if (dataShape[1] != (weightShape[1] * group))
-    emitError("Channel dimension mismatch.");
+    emitError("Channel dimension mismatch");
 
   // Note: the value of the group attribut only impacts the way the
   // computation is carried out and not the actual output size.
@@ -834,11 +812,10 @@ void ONNXConvNoBiasOp::inferShapes() {
   // argument.
   SmallVector<int64_t, 2> kernelDims;
   if (auto kernelShape = kernel_shapeAttr()) {
-    if (kernelShape.getValue().size() != nDims)
-      emitError("kernel_shape length incompatible with spatial dimensions.");
+    if (ArrayAttrSize(kernelShape) != nDims)
+      emitError("kernel_shape length incompatible with spatial dimensions");
     for (int i = 0; i < nDims; ++i)
-      kernelDims.emplace_back(
-          (kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
+      kernelDims.emplace_back(ArrayAttrIntVal(kernelShape, i));
   } else {
     for (int i = 0; i < nDims; ++i)
       kernelDims.emplace_back(weightShape[i + 2]);
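The ArrayAttrSize and ArrayAttrIntVal helpers this hunk switches to are defined in a part of the change not shown here. Judging from the expressions they replace, they are presumably thin wrappers along the following lines; this is a sketch inferred from the old code, not the actual ONNF definitions:

// Presumed helpers, reconstructed from the expressions they replace above.
// ONNF presumably also provides overloads taking Optional<ArrayAttr>, since
// later hunks call ArrayAttrSize(dilationsOpt) on optional attributes.
#include <cstdint>
#include "mlir/IR/Attributes.h"

int64_t ArrayAttrSize(mlir::ArrayAttr a) { return a.getValue().size(); }

int64_t ArrayAttrIntVal(mlir::ArrayAttr a, int i) {
  return a.getValue()[i].cast<mlir::IntegerAttr>().getInt();
}

Centralizing the attribute access removes the repeated cast<IntegerAttr>() boilerplate that made the bounds checks hard to read.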
@@ -856,11 +833,11 @@ void ONNXConvNoBiasOp::inferShapes() {
   // From a dimensionality perspective the kernel size becomes the dilated
   // kernel size.
   if (auto dilations = dilationsAttr()) {
-    if (dilations.getValue().size() != nDims)
-      emitError("dilations length incompatible with spatial dimensions.");
+    if (ArrayAttrSize(dilations) != nDims)
+      emitError("dilations length incompatible with spatial dimensions");
     for (int i = 0; i < nDims; ++i)
-      kernelDims[i] = (kernelDims[i] + 1) *
-          (dilations.getValue()[i]).cast<IntegerAttr>().getInt() - 1;
+      kernelDims[i] =
+          (kernelDims[i] + 1) * ArrayAttrIntVal(dilations, i) - 1;
   }
 
   // Subtract kernel dimensions from input data dimensions.
@@ -873,16 +850,14 @@ void ONNXConvNoBiasOp::inferShapes() {
   // present then pads is considered to be all zeros (no padding).
   if (auto pads = padsAttr()) {
     // pads consists of two entries for each spatial axis.
-    if (pads.getValue().size() != 2 * nDims)
-      emitError("pads size is not twice the spatial size.");
+    if (ArrayAttrSize(pads) != 2 * nDims)
+      emitError("pads size is not twice the spatial size");
 
     for (int i = 0; i < nDims; ++i) {
       // Padding for beginning of axis.
-      int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
-      outSpatialDims[i] += p;
+      outSpatialDims[i] += ArrayAttrIntVal(pads, i);
       // Padding for end of axis.
-      p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
-      outSpatialDims[i] += p;
+      outSpatialDims[i] += ArrayAttrIntVal(pads, i + nDims);
     }
   }
 } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
@@ -898,16 +873,15 @@ void ONNXConvNoBiasOp::inferShapes() {
   } else if (autoPad == "VALID") {
     // No padding
   } else {
-    emitError("Unexpected attribute value for auto_pad.");
+    emitError("Unexpected attribute value for auto_pad");
   }
 
   // Strides
   if (auto strides = ONNXConvNoBiasOp::stridesAttr()) {
-    if (strides.getValue().size() != nDims)
-      emitError("strides length incompatible with spatial dimensions.");
+    if (ArrayAttrSize(strides) != nDims)
+      emitError("strides length incompatible with spatial dimensions");
     for (int i = 0; i < nDims; ++i) {
-      int64_t stride =
-          strides.getValue()[i].cast<IntegerAttr>().getInt();
+      int64_t stride = ArrayAttrIntVal(strides, i);
       outSpatialDims[i] = floor(outSpatialDims[i] / stride);
     }
   }
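Taken together, the Conv hunks above build the spatial output size incrementally: subtract the (possibly dilated) kernel extent, add both pads, then floor-divide by the stride. As a cross-check, a minimal standalone program using the usual ONNX closed form; the trailing +1 is applied in a part of the function not shown in this diff, so treat its placement here as an assumption:

// Closed-form spatial-size rule:
//   out = floor((D + padBegin + padEnd - ((k - 1) * d + 1)) / stride) + 1
// Assumes the final +1 happens in elided code.
#include <cstdint>
#include <cstdio>

int64_t outSpatial(int64_t D, int64_t k, int64_t padBegin, int64_t padEnd,
                   int64_t stride, int64_t dilation) {
  int64_t kEff = (k - 1) * dilation + 1;              // dilated kernel extent
  return (D + padBegin + padEnd - kEff) / stride + 1; // floor for D >= kEff
}

int main() {
  // 32-wide axis, 3x3 kernel, pads 1 and 1, stride 2, no dilation -> 16,
  // matching the 32 -> 16 maxpool cases in the tests near the end of this diff.
  std::printf("%lld\n", (long long)outSpatial(32, 3, 1, 1, 2, 1));
  return 0;
}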
@@ -922,112 +896,108 @@ void ONNXConvNoBiasOp::inferShapes() {
 //===----------------------------------------------------------------------===//
 
 // MaxPoolSingleOut
+// Infer shape attributes output:
+//   -  auto_pad set to NOTSET;
+//   -  dilations, strides: set to 1 if not defined by user;
+//   -  pads: set to proper value, 0 if not defined by user.
 
 void ONNXMaxPoolSingleOutOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!X().getType().isa<RankedTensorType>())
     return;
+  auto builder = mlir::Builder(this->getContext());
 
-  // 1) get shape of input
+  // 1) Get shape of input.
   auto xTy = X().getType().cast<RankedTensorType>();
   auto xShape = xTy.getShape();
   auto xRank = xShape.size();
 
-  // 2) analyse parameters
-  // get kernel sizes from kernel_shape attribute
+  // 2) Analyse parameters. Get kernel sizes from kernel_shape attribute.
   auto kernelShape = kernel_shape();
   if (!kernelShape)
-    emitError("kernel_shape is a mandatory attribute for which there is no default.");
-  auto kernelShapeArray = kernelShape.getValue();
-  auto kernelRank = kernelShape.size();
+    emitError(
+        "kernel_shape is a mandatory attribute for which there is no default");
+  auto kernelRank = ArrayAttrSize(kernelShape);
   if (kernelRank > xRank)
-    emitError("kernel_shape spatial dimension is too large.");
+    emitError("kernel_shape spatial dimension is too large");
   auto kernelOffset = xRank - kernelRank;
 
-  // ceil mode
+  // Ceil mode.
   auto ceilMode = ceil_mode().getSExtValue();
 
-  // dilatation
-  SmallVector<int64_t, 4> actualDilations;
+  // Dilatation.
   auto dilationsOpt = dilations();
   if (dilationsOpt.hasValue()) {
-    auto dilationsArray = dilationsOpt.getValue().getValue(); // opt -> attr -> array
-    if (dilationsArray.size() != kernelRank)
-      emitError("dialation rank is not the same as the spatial rank.");
-    // fill in the actual values
+    if (ArrayAttrSize(dilationsOpt) != kernelRank)
+      emitError("dialation rank is not the same as the spatial rank");
+    // Test values.
     for (int i = 0; i < kernelRank; ++i) {
-      int64_t d = (dilationsArray[i]).cast<IntegerAttr>().getInt();
-      if (d < 1)
-        emitError("dialation value must be nonzero positive.");
-      actualDilations.emplace_back(d);
+      if (ArrayAttrIntVal(dilationsOpt, i) < 1)
+        emitError("dialation value must be nonzero positive");
     }
   } else {
-    for(int i=0; i < kernelRank; ++i) {
-      actualDilations.emplace_back(1);
-    }
+    // Default dilatation is needed.
+    SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
+    // Convert to ArrayRef, then build attribute, then store attribute.
+    ArrayRef<int64_t> defaultRefs(defaultVals);
+    auto defaultAttr = builder.getI64ArrayAttr(defaultRefs);
+    dilationsAttr(defaultAttr);
+    dilationsOpt = dilations();
   }
 
-  // storage order
+  // Storage order.
+  auto storageOrder = storage_order().getSExtValue();
+  if (storageOrder != 0)
+    emitError("column major storage order not supported at this time");
 
-  // strides
-  SmallVector<int64_t, 4> actualStrides;
+  // Strides.
   auto stridesOpt = strides();
   if (stridesOpt.hasValue()) {
-    auto stridesArray = stridesOpt.getValue().getValue();
-    if (stridesArray.size() != kernelRank)
-      emitError("strides rank is not the same as the spatial rank.");
-    // fill in the actual values
+    if (ArrayAttrSize(stridesOpt) != kernelRank)
+      emitError("strides rank is not the same as the spatial rank");
+    // Check values.
     for (int i = 0; i < kernelRank; ++i) {
-      int64_t s = (stridesArray[i]).cast<IntegerAttr>().getInt();
-      if (s < 1)
-        emitError("strides value must be nonzero positive.");
-      actualStrides.emplace_back(s);
+      if (ArrayAttrIntVal(stridesOpt, i) < 1)
+        emitError("strides value must be nonzero positive");
     }
   } else {
-    for(int i=0; i < kernelRank; ++i) {
-      actualStrides.emplace_back(1);
-    }
+    SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
+    // Convert to ArrayRef, then build attribute, then store attribute.
+    ArrayRef<int64_t> defaultRefs(defaultVals);
+    auto defaultAttr = builder.getI64ArrayAttr(defaultRefs);
+    stridesAttr(defaultAttr);
+    stridesOpt = strides();
   }
 
-  // now try to find padding, getting auto_pad attribute first
+  // Now try to find padding, getting auto_pad attribute first.
   auto autoPad = auto_pad();
-  // and then investigate the various different cases
-  SmallVector<int64_t, 4> actualPads;
-  auto defaultPads = false;
+  // And then investigate the various different cases.
+  SmallVector<int64_t, 4> actualPads(2 * kernelRank, 0);
   if (autoPad == "NOTSET") {
     auto padsOpt = pads();
     if (padsOpt.hasValue()) {
-      auto padsArray = padsOpt.getValue().getValue();
-      // pads consists of two entries for each spatial axis.
-      if (padsArray.size() != 2 * kernelRank)
-        emitError("pads rank is not twice the spatial rank.");
-      // fill in the actual values
-      for (int i = 0; i < 2*kernelRank; ++i) {
-        int64_t p = (padsArray[i]).cast<IntegerAttr>().getInt();
+      // Pads consists of two entries for each spatial axis.
+      if (ArrayAttrSize(padsOpt) != 2 * kernelRank)
+        emitError("pads rank is not twice the spatial rank");
+      // Check values
+      for (int i = 0; i < 2 * kernelRank; ++i) {
+        int64_t p = ArrayAttrIntVal(padsOpt, i);
         if (p < 0)
-          emitError("pads value must be nonnegative.");
-        actualPads.emplace_back(p);
+          emitError("pads value must be nonnegative");
+        actualPads[i] = p;
       }
-    } else {
-      // pads are not defined, default to value 0
-      defaultPads = true;
     }
-  } else if (autoPad == "VALID") {
-    defaultPads = true;
   } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
-    // init pad with zero
-    for(int i=0; i<2*kernelRank; ++i) {
-      actualPads.emplace_back(0);
-    }
-    for(int i=0; i<kernelRank; ++i) {
+    for (int i = 0; i < kernelRank; ++i) {
       auto inputSpatialShape = xShape[kernelOffset + i];
-      auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
-      auto dilations = actualDilations[i];
-      auto strideSpatialShape = actualStrides[i];
-      int64_t outputSpatialShape = ceil((1.0 * inputSpatialShape) /
-                                        (1.0 * strideSpatialShape));
+      auto kernelSpatialShape = ArrayAttrIntVal(kernelShape, i);
+      auto dilations = ArrayAttrIntVal(dilationsOpt, i);
+      auto strideSpatialShape = ArrayAttrIntVal(stridesOpt, i);
+      int64_t outputSpatialShape =
+          ceil((1.0 * inputSpatialShape) / (1.0 * strideSpatialShape));
       auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape +
-          ((kernelSpatialShape - 1) * dilations + 1) - inputSpatialShape;
+                      ((kernelSpatialShape - 1) * dilations + 1) -
+                      inputSpatialShape;
       actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2;
       if (sumOfPad % 2 != 0) {
         if (autoPad == "SAME_UPPER") {
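As a concrete check of the SAME_UPPER / SAME_LOWER arithmetic above, here is a small standalone program (plain C++, mirroring the hunk's formulas) applied to the 5x5x16x13 input with a 4x4 kernel and 4x4 strides used by the tests at the end of this diff. The branch elided here presumably gives the odd surplus to the end of the axis for SAME_UPPER and to the beginning for SAME_LOWER, which is what the updated test CHECK lines confirm:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  int64_t in[2] = {16, 13};
  int64_t k = 4, stride = 4, dilation = 1;
  for (int i = 0; i < 2; ++i) {
    // Same formulas as the hunk above.
    int64_t out = (int64_t)std::ceil((1.0 * in[i]) / (1.0 * stride));
    int64_t sumOfPad =
        (out - 1) * stride + ((k - 1) * dilation + 1) - in[i];
    int64_t begin = sumOfPad / 2;   // actualPads[i]
    int64_t end = sumOfPad - begin; // odd surplus to the end (SAME_UPPER)
    std::printf("axis %d: out=%lld pads=(%lld,%lld)\n", i, (long long)out,
                (long long)begin, (long long)end);
  }
  // Prints out=4 pads=(0,0) for the 16 axis and out=4 pads=(1,2) for the
  // 13 axis, i.e. pads = [0, 1, 0, 2], exactly what
  // test_default_maxpoolsingleout_upper now CHECKs for.
  return 0;
}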
@@ -1037,29 +1007,29 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
         }
       }
     }
-  } else {
-    emitError("auto_pad of unknown / unsupported value.");
+  } else if (autoPad != "VALID") {
+    emitError("auto_pad of unknown / unsupported value");
   }
-  // handle case where default pad values must be used
-  if (defaultPads) {
-    for(int i=0; i<2*kernelRank; ++i) {
-      actualPads.emplace_back(0);
-    }
+  // Set pads values in attributes.
+  {
+    ArrayRef<int64_t> defaultRefs(actualPads);
+    auto defaultAttr = builder.getI64ArrayAttr(defaultRefs);
+    padsAttr(defaultAttr);
+    auto defaultAutoPadAttr = builder.getStringAttr("NOTSET");
+    auto_padAttr(defaultAutoPadAttr);
   }
 
-  // initialize output shape
+  // Initialize output shape.
   SmallVector<int64_t, 4> yShape(xShape.begin(), xShape.end());
-  // for all kernel dimensions
-  for(int i=0; i<kernelRank; ++i) {
+  // Process for all kernel dimensions.
+  for (int i = 0; i < kernelRank; ++i) {
     auto inputSpatialShape = xShape[kernelOffset + i];
-    auto padShape = actualPads[i] + actualPads[kernelRank+i];
-    auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
-    auto dilations = actualDilations[i];
-    auto strideSpatialShape = actualStrides[i];
-    ///output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] -
-    // ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1)
+    auto padShape = actualPads[i] + actualPads[kernelRank + i];
+    auto kernelSpatialShape = ArrayAttrIntVal(kernelShape, i);
+    auto dilations = ArrayAttrIntVal(dilationsOpt, i);
+    auto strideSpatialShape = ArrayAttrIntVal(stridesOpt, i);
     double numerator = inputSpatialShape + padShape -
                        ((kernelSpatialShape - 1) * dilations + 1);
     double denominator = strideSpatialShape;
     int64_t res;
     if (ceilMode) {
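The ceilMode branch itself falls between this hunk and the next, but the formula in the comment deleted above pins down the intent: res is numerator / denominator plus one, rounded up when ceil_mode is set and down otherwise (an assumption as far as the exact elided code goes). A quick standalone check against the nonuniform-padding tests below, for the 30-wide axis with a 2x2 kernel, pads [1, 0, 0, 0] and strides [2, 2]:

#include <cmath>
#include <cstdio>

int main() {
  // numerator = input + padBegin + padEnd - ((kernel - 1) * dilation + 1)
  double numerator = 30 + (1 + 0) - ((2 - 1) * 1 + 1); // = 29
  double denominator = 2.0;                            // stride
  std::printf("floor: %d, ceil: %d\n",
              (int)(std::floor(numerator / denominator) + 1),  // 15
              (int)(std::ceil(numerator / denominator) + 1));  // 16
  // 15 matches test_default_maxpoolsingleout_strides_nonunifpad (ceil_mode=0)
  // and 16 matches ..._strides_nonunifpad_ceil (ceil_mode=1) further below.
  return 0;
}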
@@ -1069,7 +1039,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
     }
     yShape[kernelOffset + i] = res;
   }
-  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+  auto arrayTy = X().getType().cast<RankedTensorType>();
   getResult().setType(RankedTensorType::get(yShape, arrayTy.getElementType()));
 }
 
@@ -1152,10 +1122,10 @@ void ONNXPadConstantValuePadOp::inferShapes(){
 // Unsqueeze
 
 void ONNXUnsqueezeOp::inferShapes() {
-  if (!getOperand().getType().isa<RankedTensorType>())
+  if (!data().getType().isa<RankedTensorType>())
     return;
 
-  auto operandTy = getOperand().getType().cast<RankedTensorType>();
+  auto operandTy = data().getType().cast<RankedTensorType>();
   int inRank = operandTy.getRank();
 
   ArrayAttr axisAttrs = axesAttr();
@@ -1171,10 +1141,10 @@ void ONNXUnsqueezeOp::inferShapes() {
       if (std::find(axes.begin(), axes.end(), axis) == axes.end())
         axes.emplace_back(axis);
       else
-        emitError("Duplicated axes.");
+        emitError("Duplicated axes");
     }
   } else {
-    emitError("Axes attribute is required.");
+    emitError("Axes attribute is required");
   }
 
   SmallVector<int64_t, 4> dims;
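For reference, the Unsqueeze shape rule these two hunks implement: each entry of axes inserts a new dimension of size 1, so the output rank is inRank plus the number of distinct axes. A minimal plain-C++ sketch of that rule, independent of the MLIR plumbing above:

#include <algorithm>
#include <cstdint>
#include <vector>

// Insert a unit dimension at every position named in `axes`.
std::vector<int64_t> unsqueezeShape(const std::vector<int64_t> &inShape,
                                    std::vector<int64_t> axes) {
  std::sort(axes.begin(), axes.end());
  int64_t outRank = inShape.size() + axes.size();
  std::vector<int64_t> out;
  size_t src = 0, ax = 0;
  for (int64_t i = 0; i < outRank; ++i) {
    if (ax < axes.size() && axes[ax] == i) {
      out.push_back(1); // new unit dimension at axis i
      ++ax;
    } else {
      out.push_back(inShape[src++]); // copy the next input dimension
    }
  }
  return out;
}
// Example: unsqueezeShape({3, 4}, {0, 3}) == {1, 3, 4, 1}.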
[File diff suppressed because it is too large]
@@ -127,6 +127,10 @@ int main(int argc, char *argv[]) {
 
   if (emissionTarget >= EmitMLIR) {
     pm.addPass(mlir::createLowerToKrnlPass());
+    // An additional pass of canonicalization is helpful because lowering
+    // from ONNX dialect to Standard dialect exposes additional canonicalization
+    // oppertunities.
+    pm.addPass(mlir::createCanonicalizerPass());
     pm.addPass(mlir::createLowerKrnlPass());
   }
 
@@ -28,6 +28,11 @@ void ONNXAddOp::getCanonicalizationPatterns(
     OwningRewritePatternList& results, MLIRContext* context) {
   results.insert<MulAddToGemmOptPattern>(context);
 }
 
+void ONNXGemmOp::getCanonicalizationPatterns(
+    OwningRewritePatternList& results, MLIRContext* context) {
+  results.insert<FuseGemmFollowedByAddition>(context);
+}
 /// on the ONNXIdentityOp.
 void ONNXIdentityOp::getCanonicalizationPatterns(
     OwningRewritePatternList& results, MLIRContext* context) {
@@ -26,6 +26,7 @@ include "dialect/onnx/onnx.td"
 
 def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
 class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
+def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
 
 //===----------------------------------------------------------------------===//
 // Pattern-Match and Rewrite
@@ -41,6 +42,11 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
                 (ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
                 [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>;
 
+// onnx.add(onnx.Gemm(%X, %Y, None), %Z) = onnx.Gemm(%X, %Y, %Z)
+def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none, $alpha, $beta, $transA, $transB), $bias),
+                (ONNXGemmOp $m1, $m2, $bias, $alpha, $beta, $transA, $transB),
+                [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasNoneType $none)]>;
+
 // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
 def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
                                      (replaceWithValue $arg)>;
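Why this rewrite is sound, spelled out: ONNX defines Gemm(A, B, C) = alpha * op(A) * op(B) + beta * C, and with C = None the beta term simply drops. So

    Add(Gemm(A, B, None), Z) = alpha * op(A) * op(B) + Z = Gemm(A, B, Z)   provided beta = 1.

The pattern forwards $beta unchanged, so the equality relies on beta carrying its ONNX default of 1.0; that is the value the CHECK line of the new test_gemm_add_fusion test below expects (beta = 1.000000e+00).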
@@ -118,7 +118,6 @@ public:
                 op->getName().getStringRef() != "onnx.Identity" &&
                 op->getName().getStringRef() != "onnx.MatMul" &&
                 op->getName().getStringRef() != "onnx.Gemm" &&
-                op->getName().getStringRef() != "onnx.GemmNoBias" &&
                 op->getName().getStringRef() != "onnx.Reshape" &&
                 op->getName().getStringRef() != "onnx.Transpose" &&
                 op->getName().getStringRef() != "onnx.ReduceMax" &&
@@ -101,3 +101,14 @@ func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>
 // CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
 // CHECK-NEXT: return %1 : tensor<*xf32>
 }
 
+//CHECK-LABEL: @test_gemm_add_fusion(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> {
+func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> {
+  %cst = constant unit
+  %0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32>
+  %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+
+  // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32>
+  // return [[GEMM]] : tensor<*xf32>
+}
@@ -806,35 +806,6 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
 // CHECK: }
 }
 
-func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
-  %0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
-  "std.return"(%0) : (tensor<*xf32>) -> ()
-
-  // CHECK-LABEL: test_gemm_no_bias
-  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
-  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
-  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
-  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
-  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
-  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
-  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
-  // CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
-  // CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
-  // CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
-  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
-  // CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: }
-  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
-  // CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: }
-  // CHECK: return [[RES]] : memref<10x10xf32>
-  // CHECK: }
-}
-
 func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
@@ -6,7 +6,7 @@ func @test_default_maxpoolsingleout(%arg0 : tensor<5x5x32x32xf32>) -> tensor<*xf
   "std.return"(%0) : (tensor<*xf32>) -> ()
 }
 // CHECK-LABEL: test_default_maxpoolsingleout
-// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "VALID", ceil_mode = 0 : i64, kernel_shape = [3, 3], pads = [1, 1, 1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32>
+// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32>
 // CHECK: return [[RES]] : tensor<5x5x30x30xf32>
 
 
@@ -16,7 +16,7 @@ func @test_default_maxpoolsingleout_defpad(%arg0 : tensor<5x5x32x32xf32>) -> ten
   "std.return"(%0) : (tensor<*xf32>) -> ()
 }
 // CHECK-LABEL: test_default_maxpoolsingleout_defpad
-// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [3, 3]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32>
+// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32>
 // CHECK: return [[RES]] : tensor<5x5x30x30xf32>
 
 
@@ -26,7 +26,7 @@ func @test_default_maxpoolsingleout_pad(%arg0 : tensor<5x5x32x32xf32>) -> tensor
   "std.return"(%0) : (tensor<*xf32>) -> ()
 }
 // CHECK-LABEL: test_default_maxpoolsingleout_pad
-// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [3, 3], pads = [1, 1, 1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32>
+// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32>
 // CHECK: return [[RES]] : tensor<5x5x32x32xf32>
 
 
@@ -36,7 +36,7 @@ func @test_default_maxpoolsingleout_pad_nonunif(%arg0 : tensor<5x5x32x32xf32>) -
   "std.return"(%0) : (tensor<*xf32>) -> ()
 }
 // CHECK-LABEL: test_default_maxpoolsingleout_pad_nonunif
-// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [2, 1, 1, 0]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x31x31xf32>
+// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [5, 3], pads = [2, 1, 1, 0], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x31x31xf32>
 // CHECK: return [[RES]] : tensor<5x5x31x31xf32>
 
 
@@ -46,7 +46,7 @@ func @test_default_maxpoolsingleout_strides(%arg0 : tensor<5x5x32x32xf32>) -> te
   "std.return"(%0) : (tensor<*xf32>) -> ()
 }
 // CHECK-LABEL: test_default_maxpoolsingleout_strides
-// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x16x16xf32>
+// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x16x16xf32>
 // CHECK: return [[RES]] : tensor<5x5x16x16xf32>
 
 
@@ -56,7 +56,7 @@ func @test_default_maxpoolsingleout_strides_nonunifpad(%arg0 : tensor<5x5x30x32x
   "std.return"(%0) : (tensor<*xf32>) -> ()
 }
 // CHECK-LABEL: test_default_maxpoolsingleout_strides_nonunifpad
-// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32>
+// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32>
 // CHECK: return [[RES]] : tensor<5x5x15x16xf32>
 
 
@@ -66,7 +66,7 @@ func @test_default_maxpoolsingleout_strides_nonunifpad_ceil(%arg0 : tensor<5x5x3
   "std.return"(%0) : (tensor<*xf32>) -> ()
 }
 // CHECK-LABEL: test_default_maxpoolsingleout_strides_nonunifpad_ceil
-// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 1 : i64, kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x16x16xf32>
+// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 1 : i64, dilations = [1, 1], kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x16x16xf32>
 // CHECK: return [[RES]] : tensor<5x5x16x16xf32>
 
 
@@ -76,7 +76,7 @@ func @test_default_maxpoolsingleout_strides_dilatation(%arg0 : tensor<5x5x8x8xf3
   "std.return"(%0) : (tensor<*xf32>) -> ()
 }
 // CHECK-LABEL: test_default_maxpoolsingleout_strides_dilatation
-// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [2, 2], kernel_shape = [2, 2], strides = [3, 3]} : (tensor<5x5x8x8xf32>) -> tensor<5x5x2x2xf32>
+// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [2, 2], kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [3, 3]} : (tensor<5x5x8x8xf32>) -> tensor<5x5x2x2xf32>
 // CHECK: return [[RES]] : tensor<5x5x2x2xf32>
 
 /// Test the default behavior of Max Pool with dilatation
@@ -85,7 +85,7 @@ func @test_default_maxpoolsingleout_upper(%arg0 : tensor<5x5x16x13xf32>) -> tens
   "std.return"(%0) : (tensor<*xf32>) -> ()
 }
 // CHECK-LABEL: test_default_maxpoolsingleout_upper
-// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "SAME_UPPER", ceil_mode = 0 : i64, kernel_shape = [4, 4], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32>
+// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [4, 4], pads = [0, 1, 0, 2], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32>
 // CHECK: return [[RES]] : tensor<5x5x4x4xf32>
 
 
@@ -95,6 +95,6 @@ func @test_default_maxpoolsingleout_lower(%arg0 : tensor<5x5x16x13xf32>) -> tens
   "std.return"(%0) : (tensor<*xf32>) -> ()
 }
 // CHECK-LABEL: test_default_maxpoolsingleout_lower
-// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "SAME_LOWER", ceil_mode = 0 : i64, kernel_shape = [4, 4], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32>
+// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [4, 4], pads = [0, 2, 0, 1], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32>
 // CHECK: return [[RES]] : tensor<5x5x4x4xf32>
 
@@ -1 +1 @@
-Subproject commit 1439eab5542c625bb3da49860f0cd68c3eafdc18
+Subproject commit 553df22c67bee5f0fe6599cff60f1afc6748c635