Merge remote-tracking branch 'upstream/master' into shapeinference-pad

chentong 2020-02-25 15:54:18 -05:00
commit 4079ee1f26
35 changed files with 4683 additions and 4636 deletions


@@ -38,7 +38,7 @@ jobs:
       - run:
           name: Run End-To-End Tests
           command: |
-            sudo pip install -q onnx
+            sudo pip install -q -e ./ONNF/third_party/onnx
             cd ONNF/build
             cmake --build . --target run-onnx-backend-test
       - run:


@@ -1,2 +1,3 @@
 BasedOnStyle: LLVM
 AlwaysBreakTemplateDeclarations: Yes
+AlignAfterOpenBracket: DontAlign
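As an aside (not part of this patch), `AlignAfterOpenBracket: DontAlign` tells clang-format to indent wrapped call arguments with the normal continuation indent instead of aligning them under the opening parenthesis. A hypothetical snippet, with made-up names, showing the two layouts in comments:

```cpp
// Hypothetical code, only to illustrate the wrapping difference.
int takeThree(int firstArgument, int secondArgument, int thirdArgument);

int use() {
  // AlignAfterOpenBracket: Align (the LLVM default) -- wrapped args line up under '(':
  //   takeThree(firstArgument, secondArgument,
  //             thirdArgument);
  // AlignAfterOpenBracket: DontAlign -- wrapped args use a plain continuation indent:
  //   takeThree(firstArgument, secondArgument,
  //       thirdArgument);
  return takeThree(1, 2, 3);
}
```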

.gitignore

@@ -30,3 +30,145 @@
*.exe
*.out
*.app
# Filesystem
.DS_Store
# The following .gitignore content is taken from
# https://github.com/github/gitignore/blob/master/Python.gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/


@ -327,10 +327,10 @@ ONNX BatchNormalization operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values
1. `out_mean`: memref of any type values or tensor of any type values 1. `out_mean`: memref of any type values or tensor of any type values or none type
1. `out_var`: memref of any type values or tensor of any type values 1. `out_var`: memref of any type values or tensor of any type values or none type
1. `saved_mean`: memref of any type values or tensor of any type values 1. `saved_mean`: memref of any type values or tensor of any type values or none type
1. `saved_var`: memref of any type values or tensor of any type values 1. `saved_var`: memref of any type values or tensor of any type values or none type
### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp) ### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
ONNX BatchNormalization operation in test mode ONNX BatchNormalization operation in test mode
@ -375,12 +375,12 @@ ONNX BitShift operation
"Bitwise shift operator performs element-wise operation. For each input element, if the" "Bitwise shift operator performs element-wise operation. For each input element, if the"
" attribute "direction" is "RIGHT", this operator moves its binary representation toward" " attribute \"direction\" is \"RIGHT\", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute "direction"" " the right side so that the input value is effectively decreased. If the attribute \"direction\""
" is "LEFT", bits of binary representation moves toward the left side, which results the" " is \"LEFT\", bits of binary representation moves toward the left side, which results the"
" increase of its actual value. The input X is the tensor to be shifted and another input" " increase of its actual value. The input X is the tensor to be shifted and another input"
" Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4]," " Y specifies the amounts of shifting. For example, if \"direction\" is \"Right\", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with" " and S is [1, 1], the corresponding output Z would be [0, 2]. If \"direction\" is \"LEFT\" with"
" X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]." " X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]."
" " " "
" Because this operator supports Numpy-style broadcasting, X's and Y's shapes are" " Because this operator supports Numpy-style broadcasting, X's and Y's shapes are"
@ -413,15 +413,15 @@ ONNX Cast operation
"the converted type. The 'to' argument must be one of the data types specified" "the converted type. The 'to' argument must be one of the data types specified"
"in the 'DataType' enum field in the TensorProto message." "in the 'DataType' enum field in the TensorProto message."
"" ""
"Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations" "Casting from string tensor in plain (e.g., \"3.14\" and \"1000\") and scientific numeric representations"
"(e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may" "(e.g., \"1e-5\" and \"1E8\") to float types is supported. For example, converting string \"100.5\" to an integer may"
"result 100. There are some string literals reserved for special floating-point values;" "result 100. There are some string literals reserved for special floating-point values;"
""+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively." "\"+INF\" (and \"INF\"), \"-INF\", and \"NaN\" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly," "Any string which can exactly match \"+INF\" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors" "this case-insensitive rule is applied to \"INF\" and \"NaN\". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as "314.15926") would be used. " "to string tensors, plain floating-point representation (such as \"314.15926\") would be used. "
"Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases " "Converting non-numerical-literal string such as \"Hello World!\" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior." "of converting string representing floating-point arithmetic value, such as \"2.718\", to INT is an undefined behavior."
"" ""
"Conversion from a numerical type to any numerical type is always allowed." "Conversion from a numerical type to any numerical type is always allowed."
"User must be aware of precision loss and value change caused by range difference between two types." "User must be aware of precision loss and value change caused by range difference between two types."
@ -476,8 +476,8 @@ ONNX Clip operation
#### Operands: #### Operands:
1. `input`: memref of any type values or tensor of any type values 1. `input`: memref of any type values or tensor of any type values
1. `min`: memref of any type values or tensor of any type values 1. `min`: memref of any type values or tensor of any type values or none type
1. `max`: memref of any type values or tensor of any type values 1. `max`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -618,8 +618,8 @@ ONNX ConvInteger operation
1. `x`: memref of any type values or tensor of any type values 1. `x`: memref of any type values or tensor of any type values
1. `w`: memref of any type values or tensor of any type values 1. `w`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values 1. `x_zero_point`: memref of any type values or tensor of any type values or none type
1. `w_zero_point`: memref of any type values or tensor of any type values 1. `w_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -678,7 +678,7 @@ ONNX Conv operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values 1. `W`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -720,7 +720,7 @@ ONNX ConvTranspose operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values 1. `W`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -884,7 +884,7 @@ ONNX DequantizeLinear operation
1. `x`: memref of any type values or tensor of any type values 1. `x`: memref of any type values or tensor of any type values
1. `x_scale`: memref of any type values or tensor of any type values 1. `x_scale`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values 1. `x_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -964,7 +964,7 @@ ONNX Dropout operation
#### Results: #### Results:
1. `output`: memref of any type values or tensor of any type values 1. `output`: memref of any type values or tensor of any type values
1. `mask`: memref of any type values or tensor of any type values 1. `mask`: memref of any type values or tensor of any type values or none type
### onnx.DynamicQuantizeLinear (ONNXDynamicQuantizeLinearOp) ### onnx.DynamicQuantizeLinear (ONNXDynamicQuantizeLinearOp)
ONNX DynamicQuantizeLinear operation ONNX DynamicQuantizeLinear operation
@ -1297,9 +1297,9 @@ ONNX GRU operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values 1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values 1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values 1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values 1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -1315,8 +1315,8 @@ ONNX GRU operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values 1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.GatherElements (ONNXGatherElementsOp) ### onnx.GatherElements (ONNXGatherElementsOp)
ONNX GatherElements operation ONNX GatherElements operation
@ -1558,33 +1558,6 @@ ONNX Gather operation
1. `output`: memref of any type values or tensor of any type values 1. `output`: memref of any type values or tensor of any type values
### onnx.GemmNoBias (ONNXGemmNoBiasOp)
ONNX general matrix multiply operation without bias.
#### Description:
The "onnx.Gemm" generic matrix multiplication without bias.
#### Operands:
1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
#### Attributes:
| Attribute | MLIR Type | Description |
| :-------: | :-------: | ----------- |
| `alpha` | `FloatAttr` | 32-bit float attribute attribute |
| `beta` | `FloatAttr` | 32-bit float attribute attribute |
| `transA` | `IntegerAttr` | 64-bit integer attribute attribute |
| `transB` | `IntegerAttr` | 64-bit integer attribute attribute |
#### Results:
1. `o_Y`: memref of any type values or tensor of any type values
### onnx.Gemm (ONNXGemmOp) ### onnx.Gemm (ONNXGemmOp)
ONNX Gemm operation ONNX Gemm operation
@ -1609,7 +1582,7 @@ ONNX Gemm operation
1. `A`: memref of any type values or tensor of any type values 1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values
1. `C`: memref of any type values or tensor of any type values 1. `C`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -2013,11 +1986,11 @@ ONNX LSTM operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values 1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values 1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values 1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values 1. `initial_h`: memref of any type values or tensor of any type values or none type
1. `initial_c`: memref of any type values or tensor of any type values 1. `initial_c`: memref of any type values or tensor of any type values or none type
1. `P`: memref of any type values or tensor of any type values 1. `P`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -2033,9 +2006,9 @@ ONNX LSTM operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values 1. `Y_h`: memref of any type values or tensor of any type values or none type
1. `Y_c`: memref of any type values or tensor of any type values 1. `Y_c`: memref of any type values or tensor of any type values or none type
### onnx.LeakyRelu (ONNXLeakyReluOp) ### onnx.LeakyRelu (ONNXLeakyReluOp)
ONNX LeakyRelu operation ONNX LeakyRelu operation
@ -2160,24 +2133,24 @@ ONNX Loop operation
"" ""
" Operator inputs defined as (max_trip_count, condition_var)." " Operator inputs defined as (max_trip_count, condition_var)."
"" ""
" input ("", ""):" " input (\"\", \"\"):"
" for (int i=0; ; ++i) {" " for (int i=0; ; ++i) {"
" cond = ... // Note this value is ignored, but is required in the body" " cond = ... // Note this value is ignored, but is required in the body"
" }" " }"
"" ""
" input ("", cond) // Note this is analogous to a while loop" " input (\"\", cond) // Note this is analogous to a while loop"
" bool cond = ...;" " bool cond = ...;"
" for (int i=0; cond; ++i) {" " for (int i=0; cond; ++i) {"
" cond = ...;" " cond = ...;"
" }" " }"
"" ""
" input ("", 1) // Note this is analogous to a do-while loop" " input (\"\", 1) // Note this is analogous to a do-while loop"
" bool cond = true" " bool cond = true"
" for (int i=0; cond; ++i) {" " for (int i=0; cond; ++i) {"
" cond = ...;" " cond = ...;"
" }" " }"
"" ""
" input (trip_count, "") // Note this is analogous to a for loop" " input (trip_count, \"\") // Note this is analogous to a for loop"
" int trip_count = ..." " int trip_count = ..."
" for (int i=0; i < trip_count; ++i) {" " for (int i=0; i < trip_count; ++i) {"
" cond = ...; // ignored" " cond = ...; // ignored"
@ -2203,15 +2176,15 @@ ONNX Loop operation
" }" " }"
"" ""
" graph body-net (" " graph body-net ("
" %i[INT32, scalar] // iteration number" " %i[INT32, scalar]"
" %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used" " %keepgoing[BOOL, scalar]"
" %b_in[INT32, scalar] // incoming value of loop-carried-dependency b" " %b[INT32, scalar]"
" ) {" " ) {"
" %my_local = Add(%a, %b_in)" " %my_local = Add(%a, %b)"
" %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b" " %b_out = Sub(%a, %b)"
" %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition" " %keepgoing_out = Greater(%my_local, %b_out)"
" %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated" " %user_defined_vals = Add(%b, %b)"
" return %keepgoing_out, %b_out, %user_defined_val" " return %keepgoing_out, %b_out, %user_defined_vals"
" }" " }"
"" ""
"*Sample equivalent C code*" "*Sample equivalent C code*"
@ -2226,51 +2199,31 @@ ONNX Loop operation
" const int max_trip_count = 10; // Analogous to input M" " const int max_trip_count = 10; // Analogous to input M"
" int user_defined_vals[]; // Imagine this is resizable" " int user_defined_vals[]; // Imagine this is resizable"
" /* End implicitly-defined code */" " /* End implicitly-defined code */"
" /* initialize loop-carried variables and scan-output variables */" " for (int i=0; i < max_trip_count && keepgoing; ++i) {"
" bool keepgoing_out = keepgoing"
" int b_out = b"
""
" for (int i=0; i < max_trip_count && keepgoing_out; ++i) {"
" /* Implicitly-defined code: bind actual parameter values"
" to formal parameter variables of loop-body */"
" bool keepgoing_in = keepgoing_out; "
" bool b_in = b_out;"
""
" /* User-defined code (loop body) */" " /* User-defined code (loop body) */"
" int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine" " int my_local = a + b; // Reading values in the enclosing scope is fine"
" b_out = a - b_in;" " b = a - b; // writes fine if we specify b as a loop-carried dependency"
" keepgoing_out = my_local > b_out; " " keepgoing = my_local > b; // keepgoing is a loop-carried dependency"
" user_defined_val = b_in + b_in; // b_in and b_out are different variables" " user_defined_vals[i] = b + b;"
" /* End user-defined code */" " /* End user-defined code */"
""
" /* Implicitly defined-code */"
" user_defined_vals[i] = user_defined_val // accumulate scan-output values"
" }" " }"
" // int t = my_local; // Can't do this. my_local is not accessible here." " // my_local = 123; // Can't do this. my_local was defined in the the body"
"" ""
" // The values below are bound to the output variables of the loop and therefore accessible" " // These below values are live-out from the loop and therefore accessible"
" // b_out; user_defined_vals; keepgoing_out;" " b_out; user_defined_vals; keepgoing_out;"
" }" " }"
"" ""
"There are several things of note in this code snippet:" "There are several things of note in this code snippet:"
"" ""
"1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can" "1) Values from the enclosing scope (i.e. variable a here) are in scope and can"
" be referenced in the inputs of the loop." " be referenced in the inputs of the loop."
"2) Any values computed in the loop body that needs to be used in a subsequent" "2) Any variables which you wish to make available in the enclosing scope (i.e."
" iteration or after the loop are modelled using a pair of variables in the loop-body," " the variables b and keepgoing) must be declared as either loop-carried"
" consisting of an input variable (eg., b_in) and an output variable (eg., b_out)." " dependencies (both at the op inputs and output and at the body net input and"
" These are referred to as loop-carried dependences. The loop operation node" " output) or scan_outputs."
" supplies the input value of the input variable for the first iteration, and" "3) Values created in the body cannot be accessed in the enclosing scope."
" returns the output value of the output variable produced by the final"
" iteration."
"3) Scan_output variables are used to implicitly concatenate values computed across"
" all the iterations. In the above example, the value of user_defined_val computed"
" over all iterations are concatenated and returned as the value of user_defined_vals"
" after the loop."
"4) Values created in the body cannot be accessed in the enclosing scope,"
" except using the mechanism described above."
"" ""
"Note that the semantics of this op support "diagonal" or "wavefront" execution." "Note that the semantics of this op support \"diagonal\" or \"wavefront\" execution."
"(See Step 3 here for an example:" "(See Step 3 here for an example:"
"https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)." "https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)."
"Frontends should emit multi-layer RNNs as a series of While operators (with" "Frontends should emit multi-layer RNNs as a series of While operators (with"
@ -2280,8 +2233,8 @@ ONNX Loop operation
#### Operands: #### Operands:
1. `M`: memref of any type values or tensor of any type values 1. `M`: memref of any type values or tensor of any type values or none type
1. `cond`: memref of any type values or tensor of any type values 1. `cond`: memref of any type values or tensor of any type values or none type
1. `v_initial`: memref of any type values or tensor of any type values 1. `v_initial`: memref of any type values or tensor of any type values
#### Attributes: #### Attributes:
@ -2360,8 +2313,8 @@ ONNX MatMulInteger operation
1. `A`: memref of any type values or tensor of any type values 1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values
1. `a_zero_point`: memref of any type values or tensor of any type values 1. `a_zero_point`: memref of any type values or tensor of any type values or none type
1. `b_zero_point`: memref of any type values or tensor of any type values 1. `b_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -2444,7 +2397,7 @@ ONNX MaxPool operation
" ```" " ```"
" pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]" " pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]"
" ```" " ```"
" The output of each pooling window is maximum number of elements exclude pad. " " The output of each pooling window is maximum number of elements exclude pad."
" " " "
#### Operands: #### Operands:
@ -2466,7 +2419,7 @@ ONNX MaxPool operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values
1. `Indices`: memref of any type values or tensor of any type values 1. `Indices`: memref of any type values or tensor of any type values or none type
### onnx.MaxPoolSingleOut (ONNXMaxPoolSingleOutOp) ### onnx.MaxPoolSingleOut (ONNXMaxPoolSingleOutOp)
ONNX MaxPool operation with a single output. ONNX MaxPool operation with a single output.
@ -2552,7 +2505,7 @@ ONNX MaxUnpool operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `I`: memref of any type values or tensor of any type values 1. `I`: memref of any type values or tensor of any type values
1. `output_shape`: memref of any type values or tensor of any type values 1. `output_shape`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -2752,9 +2705,9 @@ ONNX NonMaxSuppression operation
1. `boxes`: memref of any type values or tensor of any type values 1. `boxes`: memref of any type values or tensor of any type values
1. `scores`: memref of any type values or tensor of any type values 1. `scores`: memref of any type values or tensor of any type values
1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values 1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values or none type
1. `iou_threshold`: memref of any type values or tensor of any type values 1. `iou_threshold`: memref of any type values or tensor of any type values or none type
1. `score_threshold`: memref of any type values or tensor of any type values 1. `score_threshold`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -3067,7 +3020,7 @@ ONNX Pad operation
1. `data`: memref of any type values or tensor of any type values 1. `data`: memref of any type values or tensor of any type values
1. `pads`: memref of any type values or tensor of any type values 1. `pads`: memref of any type values or tensor of any type values
1. `constant_value`: memref of any type values or tensor of any type values 1. `constant_value`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -3124,7 +3077,7 @@ ONNX QLinearConv operation
1. `w_zero_point`: memref of any type values or tensor of any type values 1. `w_zero_point`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values 1. `y_scale`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values 1. `y_zero_point`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -3188,7 +3141,7 @@ ONNX QuantizeLinear operation
1. `x`: memref of any type values or tensor of any type values 1. `x`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values 1. `y_scale`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values 1. `y_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -3270,9 +3223,9 @@ ONNX RNN operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values 1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values 1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values 1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values 1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -3287,8 +3240,8 @@ ONNX RNN operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values 1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.RandomNormalLike (ONNXRandomNormalLikeOp) ### onnx.RandomNormalLike (ONNXRandomNormalLikeOp)
ONNX RandomNormalLike operation ONNX RandomNormalLike operation
@ -3813,14 +3766,14 @@ ONNX Resize operation
"Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor." "Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor."
"Each dimension value of the output tensor is:" "Each dimension value of the output tensor is:"
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified." " output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified."
#### Operands: #### Operands:
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `roi`: memref of any type values or tensor of any type values 1. `roi`: memref of any type values or tensor of any type values
1. `scales`: memref of any type values or tensor of any type values 1. `scales`: memref of any type values or tensor of any type values
1. `sizes`: memref of any type values or tensor of any type values 1. `sizes`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -4438,7 +4391,7 @@ ONNX SequenceErase operation
#### Operands: #### Operands:
1. `input_sequence`: memref of any type values or tensor of any type values 1. `input_sequence`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values 1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -4463,7 +4416,7 @@ ONNX SequenceInsert operation
1. `input_sequence`: memref of any type values or tensor of any type values 1. `input_sequence`: memref of any type values or tensor of any type values
1. `tensor`: memref of any type values or tensor of any type values 1. `tensor`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values 1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -4680,8 +4633,8 @@ ONNX Slice operation
1. `data`: memref of any type values or tensor of any type values 1. `data`: memref of any type values or tensor of any type values
1. `starts`: memref of any type values or tensor of any type values 1. `starts`: memref of any type values or tensor of any type values
1. `ends`: memref of any type values or tensor of any type values 1. `ends`: memref of any type values or tensor of any type values
1. `axes`: memref of any type values or tensor of any type values 1. `axes`: memref of any type values or tensor of any type values or none type
1. `steps`: memref of any type values or tensor of any type values 1. `steps`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -4834,7 +4787,7 @@ ONNX SplitToSequence operation
#### Operands: #### Operands:
1. `input`: memref of any type values or tensor of any type values 1. `input`: memref of any type values or tensor of any type values
1. `split`: memref of any type values or tensor of any type values 1. `split`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -4902,9 +4855,9 @@ ONNX StringNormalizer operation
"StringNormalization performs string operations for basic cleaning." "StringNormalization performs string operations for basic cleaning."
"This operator has only one input (denoted by X) and only one output" "This operator has only one input (denoted by X) and only one output"
"(denoted by Y). This operator first examines the elements in the X," "(denoted by Y). This operator first examines the elements in the X,"
"and removes elements specified in "stopwords" attribute." "and removes elements specified in \"stopwords\" attribute."
"After removing stop words, the intermediate result can be further lowercased," "After removing stop words, the intermediate result can be further lowercased,"
"uppercased, or just returned depending the "case_change_action" attribute." "uppercased, or just returned depending the \"case_change_action\" attribute."
"This operator only accepts [C]- and [1, C]-tensor." "This operator only accepts [C]- and [1, C]-tensor."
"If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]" "If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]"
"if input shape is [C] and shape [1, 1] if input shape is [1, C]." "if input shape is [C] and shape [1, 1] if input shape is [1, C]."
@ -5034,8 +4987,8 @@ ONNX TfIdfVectorizer operation
"respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output." "respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output."
"Note that we may consider all skips up to S when generating the n-grams." "Note that we may consider all skips up to S when generating the n-grams."
"" ""
"The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and" "The examples used above are true if mode is \"TF\". If mode is \"IDF\", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF"," "the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is \"TFIDF\","
"this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute." "this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute."
"" ""
"Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor." "Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor."
@ -5123,9 +5076,9 @@ ONNX TopK operation
" contains the indices of the top k elements (original indices from the input" " contains the indices of the top k elements (original indices from the input"
" tensor)." " tensor)."
"" ""
"If "largest" is 1 (the default value) then the k largest elements are returned." "If \"largest\" is 1 (the default value) then the k largest elements are returned."
"If "sorted" is 1 (the default value) then the resulting k elements will be sorted." "If \"sorted\" is 1 (the default value) then the resulting k elements will be sorted."
"If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined." "If \"sorted\" is 0, order of returned 'Values' and 'Indices' are undefined."
"" ""
"Given two equivalent values, this operator uses the indices along the axis as" "Given two equivalent values, this operator uses the indices along the axis as"
" a tiebreaker. That is, the element with the lower index will appear first." " a tiebreaker. That is, the element with the lower index will appear first."
@ -5184,7 +5137,7 @@ ONNX Unique operation
"This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. " "This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. "
"The first output tensor 'Y' contains all unique values or subtensors of the input. " "The first output tensor 'Y' contains all unique values or subtensors of the input. "
"The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. " "The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". " "The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. \". "
"The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. " "The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. "
"" ""
"Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. " "Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. "
@ -5268,9 +5221,9 @@ ONNX Unique operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values
1. `indices`: memref of any type values or tensor of any type values 1. `indices`: memref of any type values or tensor of any type values or none type
1. `inverse_indices`: memref of any type values or tensor of any type values 1. `inverse_indices`: memref of any type values or tensor of any type values or none type
1. `counts`: memref of any type values or tensor of any type values 1. `counts`: memref of any type values or tensor of any type values or none type
### onnx.Unsqueeze (ONNXUnsqueezeOp) ### onnx.Unsqueeze (ONNXUnsqueezeOp)
ONNX Unsqueeze operation ONNX Unsqueeze operation

File diff suppressed because it is too large.


@ -62,7 +62,21 @@ target_include_directories(onnf_shape_inference
target_link_libraries(onnf_shape_inference ${MLIRLibs}) target_link_libraries(onnf_shape_inference ${MLIRLibs})
add_dependencies(onnf_shape_inference gen_krnl_ops) add_dependencies(onnf_shape_inference gen_krnl_ops)
add_library(onnf_lower_frontend conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp) add_library(onnf_lower_frontend
conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
conversion/onnx_to_krnl/math/elementwise.cpp
conversion/onnx_to_krnl/math/gemm.cpp
conversion/onnx_to_krnl/math/matmul.cpp
conversion/onnx_to_krnl/math/reduction.cpp
conversion/onnx_to_krnl/math/softmax.cpp
conversion/onnx_to_krnl/nn/conv.cpp
conversion/onnx_to_krnl/nn/normalization.cpp
conversion/onnx_to_krnl/tensor/identity.cpp
conversion/onnx_to_krnl/tensor/reshape.cpp
conversion/onnx_to_krnl/tensor/transpose.cpp
conversion/onnx_to_krnl/tensor/unsqueeze.cpp
conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
target_include_directories(onnf_lower_frontend target_include_directories(onnf_lower_frontend
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT}) ${ONNF_SRC_ROOT})


@ -121,6 +121,7 @@ private:
mlir::MLIRContext &context_; mlir::MLIRContext &context_;
mlir::ModuleOp module_; mlir::ModuleOp module_;
mlir::OpBuilder builder_; mlir::OpBuilder builder_;
mlir::Value none_;
// mapping between string name and symbol // mapping between string name and symbol
OnnxOnnfSymbolMapping frontend_symbols_; OnnxOnnfSymbolMapping frontend_symbols_;
@ -188,8 +189,9 @@ private:
} }
} }
mlir::Type elementType = auto elementOnnxType =
convertONNXTypeToMLIRType(input.type().tensor_type().elem_type()); (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size()); llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
arg_types.emplace_back( arg_types.emplace_back(
mlir::RankedTensorType::get(tensor_dims, elementType)); mlir::RankedTensorType::get(tensor_dims, elementType));
@ -287,8 +289,8 @@ private:
} }
} }
std::vector<mlir::NamedAttribute> ImportNodeAttributes( std::vector<mlir::NamedAttribute>
const onnx::NodeProto &node) { ImportNodeAttributes(const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute> attributes; std::vector<mlir::NamedAttribute> attributes;
for (int i = 0; i < node.attribute_size(); ++i) { for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i); auto attr = node.attribute(i);
@ -317,21 +319,11 @@ private:
} }
} }
// if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
// combined with 'if constexpr' the issue is the type of the output is
// different. alternative way to use variadic output for all the op
/*!
* Important onnx node which generates only one output
* @param node onnx node
* @param nIn number of expected inputs
* @param nOut number of expected outputs
* @param attrs list of desription for attributes with format {name, type,
* default}
*/
template <typename T> template <typename T>
void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut, void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
bool variadicIn = false, bool variadicOut = false) { int expectedNumResults = -1) {
bool variadicIn = expectedNumOperands == -1;
bool variadicOut = expectedNumResults == -1;
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) { for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
@ -339,6 +331,10 @@ private:
} }
} }
if (!variadicIn)
for (auto i = inputs.size(); i < expectedNumOperands; i++)
inputs.emplace_back(none_);
std::vector<mlir::Type> outputTypes; std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) { for (auto item : node.output()) {
outputTypes.push_back( outputTypes.push_back(
@ -347,49 +343,11 @@ private:
auto attributes = ImportNodeAttributes(node); auto attributes = ImportNodeAttributes(node);
llvm::StringRef OpName = node.op_type(); // TODO: Handle optional inputs.
if ((variadicIn || nIn == inputs.size()) && auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
(variadicOut || nOut == outputTypes.size())) { for (int i = 0; i < node.output().size(); i++) {
auto op = frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); *(op.getODSResults(i).begin()));
frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
op.getResult());
} else {
ImportNodeGeneric(node);
}
}
template <typename T>
void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut,
bool variadicIn = false,
bool variadicOut = false) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
auto attributes = ImportNodeAttributes(node);
llvm::StringRef OpName = node.op_type();
if ((variadicIn || nIn == inputs.size()) &&
(variadicOut || nOut == outputTypes.size())) {
auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
for (int i = 0; i < node.output().size(); i++) {
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
op.getResult(i));
}
} else {
ImportNodeGeneric(node);
} }
} }
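The key behavioral change in buildOperation above is that missing trailing optional inputs are now padded with the shared none_ value instead of falling back to ImportNodeGeneric. A minimal stand-in sketch of that padding logic, using plain strings instead of mlir::Value (the names here are illustrative, not from the patch):

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Stand-in (no MLIR dependency): pad a node's operand list up to the expected
// count with a shared "none" value when trailing optional inputs are absent.
// expectedNumOperands == -1 means the operand count is variadic.
std::vector<std::string> padOperands(std::vector<std::string> inputs,
                                     int expectedNumOperands,
                                     const std::string &none) {
  bool variadicIn = expectedNumOperands == -1;
  if (!variadicIn)
    for (auto i = inputs.size(); i < (size_t)expectedNumOperands; ++i)
      inputs.emplace_back(none);
  return inputs;
}

int main() {
  // e.g. a Conv node imported with only X and W; the optional bias B is absent.
  auto ops = padOperands({"X", "W"}, /*expectedNumOperands=*/3, "none");
  for (const auto &v : ops)
    std::printf("%s\n", v.c_str());  // X, W, none
  return 0;
}
```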
@ -398,8 +356,7 @@ private:
* c++ does not allow template specialization inside a class scope * c++ does not allow template specialization inside a class scope
* a specialized function is used * a specialized function is used
*/ */
void void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
// Conv has attribute dilations, kernel_shape, pads, the default value of // Conv has attribute dilations, kernel_shape, pads, the default value of
// which is determined by the shape of first argument. However, since the // which is determined by the shape of first argument. However, since the
// shape is unknown now, these attributes can be not generated auto // shape is unknown now, these attributes can be not generated auto
@ -413,24 +370,20 @@ private:
int nOps = node.input().size(); int nOps = node.input().size();
if (nOps == 2) if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>( buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
node, nOps, nOut);
else else
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut); buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
} }
/*! /*!
* Special handle for MaxPool operations. * Special handle for MaxPool operations.
*/ */
void ImportNodeMaxPool( void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) {
onnx::NodeProto node, int nIn, int nOut) {
int nOuts = node.output().size(); int nOuts = node.output().size();
if (nOuts == 1) { if (nOuts == 1) {
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>( buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts);
node, nIn, nOuts);
} else { } else {
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>( buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts);
node, nIn, nOuts);
} }
} }
@ -441,23 +394,10 @@ private:
int nOuts = node.output().size(); int nOuts = node.output().size();
if (nOuts == 1) { if (nOuts == 1) {
// Test mode with one output. // Test mode with one output.
ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts);
nOuts);
} else { } else {
// Training mode with four trailing optional outputs. Not handled yet. // Training mode with four trailing optional outputs. Not handled yet.
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts); buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
}
}
/*!
* Special handle for Gemm operations.
*/
void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size();
if (nOps == 2) {
ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut);
} else {
ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut);
} }
} }
@ -467,28 +407,14 @@ private:
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) { void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size(); int nOps = node.input().size();
if (nOps == 2) { if (nOps == 2) {
ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut); buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
} else { } else {
ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut); buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
} }
} }
void ImportNode(const onnx::NodeProto &node) { void ImportNode(const onnx::NodeProto &node) {
std::vector<mlir::Value> inputs; llvm::StringRef opName = node.op_type();
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
llvm::StringRef OpName = node.op_type();
// the following code is generated by gen_doc.py // the following code is generated by gen_doc.py
// refer to dialect/onnx/onnx.td for details // refer to dialect/onnx/onnx.td for details
@ -555,9 +481,11 @@ private:
ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it)); ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
} }
// import nodes in the graph // Create a NoneTyped constant.
auto node = graph.node(); none_ =
for (const auto &item : node) { builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
// Import nodes in the graph.
for (const auto &item : graph.node()) {
ImportNode(item); ImportNode(item);
} }


@ -1,320 +1,319 @@
//******************************************************** //********************************************************
// Warning: Do not modify this file directly // This file is generated on UTC-02/24/2020, 06:29:01.
// This file is automatically generated via script // Do not modify this file directly.
// Details can be found in doc/readonnxdefs.md // This file is automatically generated via script.
// Details can be found in doc/readonnxdefs.md .
//******************************************************** //********************************************************
if (OpName == "DUMMY") { if (opName == "Abs")
}else if (OpName == "Abs") { return buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1); if (opName == "Acos")
}else if (OpName == "Acos") { return buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1); if (opName == "Acosh")
}else if (OpName == "Acosh") { return buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1); if (opName == "Add")
}else if (OpName == "Add") { return buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1); if (opName == "And")
}else if (OpName == "And") { return buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1); if (opName == "ArgMax")
}else if (OpName == "ArgMax") { return buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1); if (opName == "ArgMin")
}else if (OpName == "ArgMin") { return buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1); if (opName == "Asin")
}else if (OpName == "Asin") { return buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1); if (opName == "Asinh")
}else if (OpName == "Asinh") { return buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1); if (opName == "Atan")
}else if (OpName == "Atan") { return buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1); if (opName == "Atanh")
}else if (OpName == "Atanh") { return buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1); if (opName == "AveragePool")
}else if (OpName == "AveragePool") { return buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1); if (opName == "BatchNormalization")
}else if (OpName == "BatchNormalization") { return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
ImportNodeBatchNormalization(node, 5, 5); if (opName == "BitShift")
}else if (OpName == "BitShift") { return buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1); if (opName == "Cast")
}else if (OpName == "Cast") { return buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1); if (opName == "Ceil")
}else if (OpName == "Ceil") { return buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1); if (opName == "Clip")
}else if (OpName == "Clip") { return buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1); if (opName == "Compress")
}else if (OpName == "Compress") { return buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1); if (opName == "Concat")
}else if (OpName == "Concat") { return buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, true, false); if (opName == "ConcatFromSequence")
}else if (OpName == "ConcatFromSequence") { return buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1); if (opName == "Constant")
}else if (OpName == "Constant") { return buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1); if (opName == "ConstantOfShape")
}else if (OpName == "ConstantOfShape") { return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1); if (opName == "Conv")
}else if (OpName == "Conv") { return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeConv(node, 3, 1); if (opName == "ConvInteger")
}else if (OpName == "ConvInteger") { return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1); if (opName == "ConvTranspose")
}else if (OpName == "ConvTranspose") { return buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1); if (opName == "Cos")
}else if (OpName == "Cos") { return buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1); if (opName == "Cosh")
}else if (OpName == "Cosh") { return buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1); if (opName == "CumSum")
}else if (OpName == "CumSum") { return buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1); if (opName == "DepthToSpace")
}else if (OpName == "DepthToSpace") { return buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1); if (opName == "DequantizeLinear")
}else if (OpName == "DequantizeLinear") { return buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1); if (opName == "Det")
}else if (OpName == "Det") { return buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1); if (opName == "Div")
}else if (OpName == "Div") { return buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1); if (opName == "Dropout")
}else if (OpName == "Dropout") { return buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2); if (opName == "DynamicQuantizeLinear")
}else if (OpName == "DynamicQuantizeLinear") { return buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3); if (opName == "Elu")
}else if (OpName == "Elu") { return buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1); if (opName == "Equal")
}else if (OpName == "Equal") { return buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1); if (opName == "Erf")
}else if (OpName == "Erf") { return buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1); if (opName == "Exp")
}else if (OpName == "Exp") { return buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1); if (opName == "Expand")
}else if (OpName == "Expand") { return buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1); if (opName == "EyeLike")
}else if (OpName == "EyeLike") { return buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1); if (opName == "Flatten")
}else if (OpName == "Flatten") { return buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1); if (opName == "Floor")
}else if (OpName == "Floor") { return buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1); if (opName == "GRU")
}else if (OpName == "GRU") { return buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2); if (opName == "Gather")
}else if (OpName == "Gather") { return buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1); if (opName == "GatherElements")
}else if (OpName == "GatherElements") { return buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1); if (opName == "GatherND")
}else if (OpName == "GatherND") { return buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1); if (opName == "Gemm")
}else if (OpName == "Gemm") { return buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeGemm(node, 3, 1); if (opName == "GlobalAveragePool")
}else if (OpName == "GlobalAveragePool") { return buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1); if (opName == "GlobalLpPool")
}else if (OpName == "GlobalLpPool") { return buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1); if (opName == "GlobalMaxPool")
}else if (OpName == "GlobalMaxPool") { return buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1); if (opName == "Greater")
}else if (OpName == "Greater") { return buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1); if (opName == "HardSigmoid")
}else if (OpName == "HardSigmoid") { return buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1); if (opName == "Hardmax")
}else if (OpName == "Hardmax") { return buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1); if (opName == "Identity")
}else if (OpName == "Identity") { return buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1); if (opName == "If")
}else if (OpName == "If") { return buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1); if (opName == "InstanceNormalization")
}else if (OpName == "InstanceNormalization") { return buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1); if (opName == "IsInf")
}else if (OpName == "IsInf") { return buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1); if (opName == "IsNaN")
}else if (OpName == "IsNaN") { return buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1); if (opName == "LRN")
}else if (OpName == "LRN") { return buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1); if (opName == "LSTM")
}else if (OpName == "LSTM") { return buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3); if (opName == "LeakyRelu")
}else if (OpName == "LeakyRelu") { return buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1); if (opName == "Less")
}else if (OpName == "Less") { return buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1); if (opName == "Log")
}else if (OpName == "Log") { return buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1); if (opName == "LogSoftmax")
}else if (OpName == "LogSoftmax") { return buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1); if (opName == "Loop")
}else if (OpName == "Loop") { return buildOperation<mlir::ONNXLoopOp>(node);
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1); if (opName == "LpNormalization")
}else if (OpName == "LpNormalization") { return buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1); if (opName == "LpPool")
}else if (OpName == "LpPool") { return buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1); if (opName == "MatMul")
}else if (OpName == "MatMul") { return buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1); if (opName == "MatMulInteger")
}else if (OpName == "MatMulInteger") { return buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1); if (opName == "Max")
}else if (OpName == "Max") { return buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, true, false); if (opName == "MaxPool")
}else if (OpName == "MaxPool") { return ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
ImportNodeMaxPool(node, 1, 2); if (opName == "MaxRoiPool")
}else if (OpName == "MaxRoiPool") { return buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1); if (opName == "MaxUnpool")
}else if (OpName == "MaxUnpool") { return buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1); if (opName == "Mean")
}else if (OpName == "Mean") { return buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, true, false); if (opName == "MeanVarianceNormalization")
}else if (OpName == "MeanVarianceNormalization") { return buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1); if (opName == "Min")
}else if (OpName == "Min") { return buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, true, false); if (opName == "Mod")
}else if (OpName == "Mod") { return buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1); if (opName == "Mul")
}else if (OpName == "Mul") { return buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1); if (opName == "Multinomial")
}else if (OpName == "Multinomial") { return buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1); if (opName == "Neg")
}else if (OpName == "Neg") { return buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1); if (opName == "NonMaxSuppression")
}else if (OpName == "NonMaxSuppression") { return buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1); if (opName == "NonZero")
}else if (OpName == "NonZero") { return buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1); if (opName == "Not")
}else if (OpName == "Not") { return buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1); if (opName == "OneHot")
}else if (OpName == "OneHot") { return buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1); if (opName == "Or")
}else if (OpName == "Or") { return buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1); if (opName == "PRelu")
}else if (OpName == "PRelu") { return buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1); if (opName == "Pad")
}else if (OpName == "Pad") { return ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodePad(node, 3, 1); if (opName == "Pow")
}else if (OpName == "Pow") { return buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1); if (opName == "QLinearConv")
}else if (OpName == "QLinearConv") { return buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1); if (opName == "QLinearMatMul")
}else if (OpName == "QLinearMatMul") { return buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1); if (opName == "QuantizeLinear")
}else if (OpName == "QuantizeLinear") { return buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1); if (opName == "RNN")
}else if (OpName == "RNN") { return buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2); if (opName == "RandomNormal")
}else if (OpName == "RandomNormal") { return buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1); if (opName == "RandomNormalLike")
}else if (OpName == "RandomNormalLike") { return buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1); if (opName == "RandomUniform")
}else if (OpName == "RandomUniform") { return buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1); if (opName == "RandomUniformLike")
}else if (OpName == "RandomUniformLike") { return buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1); if (opName == "Range")
}else if (OpName == "Range") { return buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1); if (opName == "Reciprocal")
}else if (OpName == "Reciprocal") { return buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1); if (opName == "ReduceL1")
}else if (OpName == "ReduceL1") { return buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1); if (opName == "ReduceL2")
}else if (OpName == "ReduceL2") { return buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1); if (opName == "ReduceLogSum")
}else if (OpName == "ReduceLogSum") { return buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1); if (opName == "ReduceLogSumExp")
}else if (OpName == "ReduceLogSumExp") { return buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1); if (opName == "ReduceMax")
}else if (OpName == "ReduceMax") { return buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1); if (opName == "ReduceMean")
}else if (OpName == "ReduceMean") { return buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1); if (opName == "ReduceMin")
}else if (OpName == "ReduceMin") { return buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1); if (opName == "ReduceProd")
}else if (OpName == "ReduceProd") { return buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1); if (opName == "ReduceSum")
}else if (OpName == "ReduceSum") { return buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1); if (opName == "ReduceSumSquare")
}else if (OpName == "ReduceSumSquare") { return buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1); if (opName == "Relu")
}else if (OpName == "Relu") { return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1); if (opName == "Reshape")
}else if (OpName == "Reshape") { return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1); if (opName == "Resize")
}else if (OpName == "Resize") { return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1); if (opName == "ReverseSequence")
}else if (OpName == "ReverseSequence") { return buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1); if (opName == "RoiAlign")
}else if (OpName == "RoiAlign") { return buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1); if (opName == "Round")
}else if (OpName == "Round") { return buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1); if (opName == "Scan")
}else if (OpName == "Scan") { return buildOperation<mlir::ONNXScanOp>(node);
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1); if (opName == "Scatter")
}else if (OpName == "Scatter") { return buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1); if (opName == "ScatterElements")
}else if (OpName == "ScatterElements") { return buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1); if (opName == "ScatterND")
}else if (OpName == "ScatterND") { return buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1); if (opName == "Selu")
}else if (OpName == "Selu") { return buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1); if (opName == "SequenceAt")
}else if (OpName == "SequenceAt") { return buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1); if (opName == "SequenceConstruct")
}else if (OpName == "SequenceConstruct") { return buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, true, false); if (opName == "SequenceEmpty")
}else if (OpName == "SequenceEmpty") { return buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1); if (opName == "SequenceErase")
}else if (OpName == "SequenceErase") { return buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1); if (opName == "SequenceInsert")
}else if (OpName == "SequenceInsert") { return buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1); if (opName == "SequenceLength")
}else if (OpName == "SequenceLength") { return buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1); if (opName == "Shape")
}else if (OpName == "Shape") { return buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1); if (opName == "Shrink")
}else if (OpName == "Shrink") { return buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1); if (opName == "Sigmoid")
}else if (OpName == "Sigmoid") { return buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1); if (opName == "Sign")
}else if (OpName == "Sign") { return buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1); if (opName == "Sin")
}else if (OpName == "Sin") { return buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1); if (opName == "Sinh")
}else if (OpName == "Sinh") { return buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1); if (opName == "Size")
}else if (OpName == "Size") { return buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1); if (opName == "Slice")
}else if (OpName == "Slice") { return buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1); if (opName == "Softmax")
}else if (OpName == "Softmax") { return buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1); if (opName == "Softplus")
}else if (OpName == "Softplus") { return buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1); if (opName == "Softsign")
}else if (OpName == "Softsign") { return buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1); if (opName == "SpaceToDepth")
}else if (OpName == "SpaceToDepth") { return buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1); if (opName == "Split")
}else if (OpName == "Split") { return buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1); if (opName == "SplitToSequence")
}else if (OpName == "SplitToSequence") { return buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1); if (opName == "Sqrt")
}else if (OpName == "Sqrt") { return buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1); if (opName == "Squeeze")
}else if (OpName == "Squeeze") { return buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1); if (opName == "StringNormalizer")
}else if (OpName == "StringNormalizer") { return buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1); if (opName == "Sub")
}else if (OpName == "Sub") { return buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1); if (opName == "Sum")
}else if (OpName == "Sum") { return buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, true, false); if (opName == "Tan")
}else if (OpName == "Tan") { return buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1); if (opName == "Tanh")
}else if (OpName == "Tanh") { return buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1); if (opName == "TfIdfVectorizer")
}else if (OpName == "TfIdfVectorizer") { return buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1); if (opName == "ThresholdedRelu")
}else if (OpName == "ThresholdedRelu") { return buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1); if (opName == "Tile")
}else if (OpName == "Tile") { return buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1); if (opName == "TopK")
}else if (OpName == "TopK") { return buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2); if (opName == "Transpose")
}else if (OpName == "Transpose") { return buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1); if (opName == "Unique")
}else if (OpName == "Unique") { return buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4); if (opName == "Unsqueeze")
}else if (OpName == "Unsqueeze") { return buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1); if (opName == "Upsample")
}else if (OpName == "Upsample") { return buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1); if (opName == "Where")
}else if (OpName == "Where") { return buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1); if (opName == "Xor")
}else if (OpName == "Xor") { return buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1);
}

View File

@ -8,404 +8,11 @@
// Krnl IR and standard operations. // Krnl IR and standard operations.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <map>
#include "mlir/Dialect/AffineOps/AffineOps.h" #include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "src/dialect/krnl/krnl_helper.hpp"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"
using namespace mlir; using namespace mlir;
//===----------------------------------------------------------------------===//
// FrontendToAffine RewritePatterns
//===----------------------------------------------------------------------===//
/// Check whether all dimensions are known at compile time.
static bool hasAllConstantDimensions(MemRefType type) {
auto memRefShape = type.getShape();
for (int i = 0; i < memRefShape.size(); ++i)
if (memRefShape[i] < 0)
return false;
return true;
}
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
static MemRefType convertToMemRefType(Type type) {
MemRefType memRefType;
auto tensorType = type.dyn_cast<TensorType>();
if (tensorType) {
assert(tensorType.hasRank() && "expected only ranked shapes");
memRefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
} else {
memRefType = type.dyn_cast<MemRefType>();
}
return memRefType;
}
/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter,
bool insertDealloc,
ArrayRef<Value> operands = {}) {
// Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc;
if (!operands.empty()) {
auto memRefShape = type.getShape();
auto rank = memRefShape.size();
std::map<int, Value> fromOperands;
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
int memRefDimIdx = rank - 1 - reversedIdx;
if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
Value maxDim = nullptr;
for (int i = 0; i < operands.size(); i++) {
auto operandShape =
operands[i].getType().cast<MemRefType>().getShape();
int operandDimIdx = operandShape.size() - 1 - reversedIdx;
if (operandDimIdx < 0)
continue;
// In case of operations with broadcasting, the dimension of the
// alloc result is the maximum size along each dimension of the
// operands.
auto operandDim =
rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
if (maxDim) {
auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
operandDim, maxDim);
maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
maxDim);
} else {
maxDim = operandDim;
}
}
fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
}
}
SmallVector<Value, 4> allocOperands;
for (int i = 0; i < rank; ++i)
if (memRefShape[i] < 0)
allocOperands.push_back(fromOperands[i]);
alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
} else {
alloc = rewriter.create<AllocOp>(loc, type);
}
// Make sure to allocate at the beginning of the block if
// all dimensions are known.
auto *parentBlock = alloc.getOperation()->getBlock();
if (hasAllConstantDimensions(type))
alloc.getOperation()->moveBefore(&parentBlock->front());
if (insertDealloc) {
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
return alloc;
}
// Determine whether the current function returns the result value of the
// op being lowered. If it does, then a dealloc should not be
// inserted.
static bool checkInsertDealloc(Operation *currentOp) {
auto parentBlock = currentOp->getBlock();
bool insertDealloc = true;
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
assert(currentOp->getNumResults() < 2 &&
"No more than one result supported (for now).");
// If there is at least one result to investigate.
if (currentOp->getNumResults() > 0) {
auto result = currentOp->getResult(0);
for (const auto &operand : op.getOperands())
if (operand == result)
insertDealloc = false;
}
});
return insertDealloc;
}
// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
std::map<int64_t, int64_t> OutInDimMap;
int64_t rank = inputTy.getRank();
// Mark reduction axes.
std::vector<bool> isReductionAxis;
for (decltype(rank) i = 0; i < rank; ++i) {
if (std::find(axes.begin(), axes.end(), i) != axes.end())
isReductionAxis.push_back(true);
else
isReductionAxis.push_back(false);
}
for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
// If it is a reduction axis, there is no relationship among dimensions.
if (isReductionAxis[inIndex]) {
if (keepdims)
outIndex++;
} else {
OutInDimMap.insert(std::make_pair(outIndex, inIndex));
outIndex++;
}
}
return OutInDimMap;
}
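As a worked illustration of the mapping above (an example added for readability, not part of this change), consider a rank-3 input reduced over axes = {1}:
// Hypothetical input rank 3, reduction axes = {1}:
//   keepdims == true  -> result rank 3, OutInDimMap = {0 -> 0, 2 -> 2}
//                        (output dim 1 is the reduced axis; it maps to nothing)
//   keepdims == false -> result rank 2, OutInDimMap = {0 -> 0, 1 -> 2}
//                        (the reduced axis is dropped and later dims shift left)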
// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
static void addDimensionToPack(ConversionPatternRewriter &rewriter,
Location loc, KrnlIterateOperandPack &pack,
Value operand, int index) {
auto shape = operand.getType().cast<MemRefType>().getShape();
if (shape[index] < 0) {
pack.pushConstantBound(0);
pack.pushOperandBound(
rewriter.create<DimOp>(loc, operand, index).getResult());
} else {
pack.pushConstantBound(0);
pack.pushConstantBound(shape[index]);
}
}
// Function that defines the KRNL dialect loops and their respective
// optimized versions.
static KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops, int64_t numLoops) {
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
loops.reserve(numLoops);
for (auto result : loopsOp.getResults())
loops.push_back(result);
// Define optimized version of the loops.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
optimizedLoops.reserve(numLoops);
for (auto result : optimizedLoopsOp.getResults())
optimizedLoops.push_back(result);
return optimizedLoopsOp;
}
// Function that emits the loops and their optimized versions.
// The function returns a reference to the inner optimization block.
static Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops,
int64_t numLoops) {
KrnlOptimizeLoopsOp optimizedLoopsOp =
emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
return &optimizedLoopsOp.region().front();
}
// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
static void emitKrnlLoopsAndIterationForOperand(
ConversionPatternRewriter &rewriter, Location loc, Value operand,
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
KrnlIterateOp &iterateOp) {
// Operand shape.
auto shape = operand.getType().cast<MemRefType>().getShape();
// Number of loops.
int64_t rank = shape.size();
// Define loops and optimized loops.
std::vector<Value> optimizedLoops;
optimizedLoopsOp =
emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Iterate over the loop nest.
for (int i = 0; i < rank; ++i)
addDimensionToPack(rewriter, loc, pack, operand, i);
iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
}
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
auto elementType = memRefType.getElementType();
unsigned sizeInBits;
if (elementType.isIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
} else {
auto vectorType = elementType.cast<VectorType>();
sizeInBits =
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
}
return llvm::divideCeil(sizeInBits, 8);
}
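For readability, a few element sizes this helper yields (illustrative values, easy to verify from the code above):
// f32             : 32 bits  -> 4 bytes
// i64             : 64 bits  -> 8 bytes
// i1              :  1 bit   -> divideCeil(1, 8) = 1 byte
// vector<4 x f32> : 128 bits -> 16 bytes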
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
MemRefType memRefType, ArrayRef<Value> operands) {
auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size();
// For unknown dimensions, we need to get dimension values at runtime in
// order to do broadcasting.
std::map<int, std::map<int, Value>> DimInfo;
// For each result dimension, compute the number of sharing operands.
// Sharing operands are operands sharing the same index (counting from the
// rightmost to the leftmost) for a given dimension.
std::map<int, int> sharedDimCount;
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
int dimIdx = rank - 1 - reversedIdx;
sharedDimCount[dimIdx] = 0;
for (int i = 0; i < operands.size(); ++i) {
auto shape = operands[i].getType().cast<MemRefType>().getShape();
if (reversedIdx <= shape.size() - 1)
sharedDimCount[dimIdx]++;
}
}
// An unknown dimension can have a value of 1 or N (N > 1).
// If its value is 1, it is a broadcasted dimension; otherwise it is a
// non-broadcasted dimension.
// We only care about unknown dimensions whose number of sharing operands is
// more than one, since they are potentially broadcasted dimensions.
for (int i = 0; i < operands.size(); ++i) {
std::map<int, Value> broadcastedDims;
auto shape = operands[i].getType().cast<MemRefType>().getShape();
int size = shape.size();
for (int j = 0; j < shape.size(); ++j) {
if (shape[j] < 0 && sharedDimCount[rank - size + j] > 1) {
auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
auto one = rewriter.create<ConstantIndexOp>(loc, 1);
auto isBroadcasted =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
broadcastedDims.insert(std::make_pair(j, isBroadcasted));
}
}
DimInfo.insert(std::make_pair(i, broadcastedDims));
}
return DimInfo;
}
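A small example of the resulting map (the operand types here are assumptions chosen for illustration): for operands %a : memref<?x3xf32> and %b : memref<5x?x3xf32> of a broadcasted binary op with result memref<5x?x3xf32>,
// DimInfo[0] = { 0 -> (dim %a, 0) == 1 }  // %a's leading dynamic dim may broadcast
// DimInfo[1] = { 1 -> (dim %b, 1) == 1 }  // %b's middle dynamic dim may broadcast
// Static dimensions, and dynamic dimensions shared by only one operand, get no
// runtime check.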
// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value> broadcastedDims) {
// `operand` must have a ranked type. This should have been checked by the
// shape inference pass.
auto operandShape = operand.getType().cast<MemRefType>().getShape();
auto rank = operandShape.size();
auto loopCount = loopIVs.size();
std::vector<Value> newLoopIVs;
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
auto dimIdx = rank - 1 - reversedIdx;
auto loopIdx = loopCount - 1 - reversedIdx;
if (operandShape[dimIdx] == 1) {
// Broadcasted dimension
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
newLoopIVs.insert(newLoopIVs.begin(), zero);
} else if ((operandShape[dimIdx] == -1) &&
(broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
// Unknown dimension; it can have a value of 1 or N (N > 1).
// If its value is 1, it is a broadcasted dimension; otherwise it is a
// non-broadcasted dimension.
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
loopIVs[loopIdx]);
newLoopIVs.insert(newLoopIVs.begin(), idx);
} else {
// Non-broadcasted dimension
newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
}
}
return newLoopIVs;
}
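A worked example of the induction-variable remapping (operand shape assumed for illustration): with result loop IVs (%i, %j, %k) and an operand of type memref<1x?xf32> whose dynamic dimension was flagged in broadcastedDims,
// newLoopIVs[0] = 0                                     // static size-1 dim always broadcasts
// newLoopIVs[1] = select(dim(%operand, 1) == 1, 0, %k)  // decided at runtime
// The leading result IV (%i) is not used to index this operand at all.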
namespace {
// This is used to obtain the scalar operation of a given type for a specific ONNX operation.
template <typename Op>
struct ScalarOp {
using FOp = void;
using IOp = void;
};
template <typename FOp>
using ScalarFOp = typename ScalarOp<FOp>::FOp;
template <typename IOp>
using ScalarIOp = typename ScalarOp<IOp>::IOp;
// Get the identity element of an operation.
// Return NULL if the operation does not have an identity element.
template <typename DataType, typename Op>
DataType getIdentityValue() {
return NULL;
}
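Specializations of getIdentityValue supply the actual identity per element type and reduction op. A sketch of what such specializations typically look like (the ReduceMax declaration appears in reduction.inc below; the bodies here are the conventional values, not quoted from the source):
#include <limits> // needed for std::numeric_limits in this sketch

template <>
float getIdentityValue<float, ONNXReduceMaxOp>() {
  return -std::numeric_limits<float>::infinity(); // identity of max
}

template <>
float getIdentityValue<float, ONNXReduceSumOp>() {
  return 0.0f; // identity of sum
}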
//===----------------------------------------------------------------------===//
// This is used in the innermost loop of a KrnlIterateOp to insert computation
// composed of one or many scalar ops.
// Use template specialization for each of the different ONNX operations.
//===----------------------------------------------------------------------===//
template <typename Op>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
auto loc = op->getLoc();
Type element_type = operands.front().getType();
if (element_type.isa<IntegerType>()) {
return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
mlir::None);
} else if (element_type.isa<FloatType>()) {
return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
mlir::None);
} else {
emitError(loc, "unsupported element type");
return nullptr;
}
}
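For example, the elementwise patterns included below specialize ScalarOp so that mapToLowerScalarOp picks the matching standard-dialect op; the Add specialization (its FOp line is visible in elementwise.inc further down, the IOp line is the assumed integer counterpart):
template <>
struct ScalarOp<ONNXAddOp> {
  using FOp = AddFOp; // float add in the standard dialect
  using IOp = AddIOp; // integer add in the standard dialect (assumed)
};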
// We divide the operator lowerings into different categories.
// These categories mostly mirror the operator categories in ONNX:
// https://github.com/onnx/onnx/tree/master/onnx/defs.
// In addition, operators that share a computation pattern are kept in the same
// category, e.g. element-wise operators all belong to the elementwise
// category.
// Math
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc"
// Tensor
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
// Neural network
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point. // EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -427,39 +34,6 @@ public:
} }
}; };
//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//
struct TensorTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
TensorTypeConverter() {
addConversion(convertType);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
if (auto type = convertToMemRefType(t)) {
results.push_back(type);
return success();
}
results.push_back(t);
return success();
}
/// Return true if the inputs and outputs of the given function type are
/// legal. [Taken from MLIR and adapted to only check the legality of the
/// inputs. Once unranked results can be handled gracefully this
/// override needs to be removed in favour of the original MLIR one.]
bool isSignatureLegal(FunctionType funcType) {
return llvm::all_of(funcType.getInputs(),
[this](Type type) { return isLegal(type); });
}
};
} // end anonymous namespace.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Frontend to Krnl Dialect lowering pass // Frontend to Krnl Dialect lowering pass
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1,4 +1,4 @@
//===----- elementwise.inc - Elementwise Ops ------------------------------===// //===----- elementwise.cpp - Elementwise Ops ------------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -8,6 +8,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
template <> template <>
struct ScalarOp<ONNXAddOp> { struct ScalarOp<ONNXAddOp> {
using FOp = AddFOp; using FOp = AddFOp;

View File

@ -1,4 +1,4 @@
//===----- gemm.inc - Lowering Gemm Op ------------------------------------===// //===----- gemm.cpp - Lowering Gemm Op ------------------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -8,6 +8,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
template <typename GemmOp> template <typename GemmOp>
struct ONNXGemmOpLowering : public ConversionPattern { struct ONNXGemmOpLowering : public ConversionPattern {
ONNXGemmOpLowering(MLIRContext *ctx) ONNXGemmOpLowering(MLIRContext *ctx)
@ -17,20 +21,22 @@ struct ONNXGemmOpLowering : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
auto has_bias = (operands.size() == 3); bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
Value A, B, C; Value A, B, C;
A = operands[0]; A = operands[0];
B = operands[1]; B = operands[1];
if (has_bias) if (hasBias)
C = operands[2]; C = operands[2];
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
    llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
auto betaAttr = FloatAttr::get(memRefType.getElementType(),
    llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr); auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -68,8 +74,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// Define loops. // Define loops.
std::vector<Value> originalLoops; std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops; std::vector<Value> optimizedLoops;
Block *optimizationBlock =
    defineLoops(rewriter, loc, originalLoops, optimizedLoops, numLoops);
// We have two Krnl loops: // We have two Krnl loops:
// - Outer loop iterates over the output matrix dimensions, and // - Outer loop iterates over the output matrix dimensions, and
@ -83,8 +89,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
outerLoops.push_back(originalLoops[i]); outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]); optimizedOuterLoops.push_back(optimizedLoops[i]);
} }
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
// Induction variables for the outer loops // Induction variables for the outer loops
for (int i = 0; i < 2; ++i) for (int i = 0; i < 2; ++i)
addDimensionToPack(rewriter, loc, outerPack, alloc, i); addDimensionToPack(rewriter, loc, outerPack, alloc, i);
@ -106,20 +111,19 @@ struct ONNXGemmOpLowering : public ConversionPattern {
int64_t K_B_Idx = (isTransB) ? 1 : 0; int64_t K_B_Idx = (isTransB) ? 1 : 0;
reductionPack.pushConstantBound(0); reductionPack.pushConstantBound(0);
if (ATy.getShape()[K_A_Idx] != -1) if (ATy.getShape()[K_A_Idx] != -1)
reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]); reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
else if (BTy.getShape()[K_B_Idx] != -1)
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
else else
if (BTy.getShape()[K_B_Idx] != -1) reductionPack.pushOperandBound(
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]); rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
else
reductionPack.pushOperandBound(
rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
// Get run-time dimension information for unknown dimensions used for // Get run-time dimension information for unknown dimensions used for
// broadcasting. // broadcasting.
// GemmOp supports unidirectional broadcasting from C to A*B. // GemmOp supports unidirectional broadcasting from C to A*B.
// Hence, it must be enough to get broadcasting information for C only. // Hence, it must be enough to get broadcasting information for C only.
std::map<int, Value> broadcastedDimInfo; std::map<int, Value> broadcastedDimInfo;
if (has_bias) { if (hasBias) {
auto shape = C.getType().cast<MemRefType>().getShape(); auto shape = C.getType().cast<MemRefType>().getShape();
for (int i = 0; i < shape.size(); ++i) { for (int i = 0; i < shape.size(); ++i) {
if (shape[i] < 0) { if (shape[i] < 0) {
@ -162,7 +166,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting) // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs); auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB); auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
if (has_bias) { if (hasBias) {
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C, auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
broadcastedDimInfo); broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs); auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
@ -210,8 +214,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
} }
}; };
void populateLoweringONNXGemmOpPattern( void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
OwningRewritePatternList &patterns, MLIRContext *ctx) { MLIRContext *ctx) {
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx); patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
} }
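For reference, the semantics the Gemm lowering above implements (the standard ONNX Gemm definition, restated here for readability rather than taken from the diff):
// Y = alpha * A' * B' + beta * C
//   where A' = transA ? transpose(A) : A and B' = transB ? transpose(B) : B.
// C is unidirectionally broadcast to the shape of A' * B', which is why the
// code above collects runtime broadcast information for C only.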

View File

@ -1,4 +1,4 @@
//===----- matmul.inc - Lowering Matmul Op --------------------------------===// //===----- matmul.cpp - Lowering Matmul Op --------------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -8,6 +8,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXMatMulOpLowering : public ConversionPattern { struct ONNXMatMulOpLowering : public ConversionPattern {
ONNXMatMulOpLowering(MLIRContext *ctx) ONNXMatMulOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,4 @@
//===----- reduction.inc - Lowering Reduction Ops -------------------------===// //===----- reduction.cpp - Lowering Reduction Ops -------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -8,6 +8,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
// Identity values // Identity values
template <> template <>
float getIdentityValue<float, ONNXReduceMaxOp>(){ float getIdentityValue<float, ONNXReduceMaxOp>(){

View File

@ -1,4 +1,4 @@
//===----- softmax.inc - Softmax Op ---------------------------------------===// //===----- softmax.cpp - Softmax Op ---------------------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -8,6 +8,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXSoftmaxOpLowering : public ConversionPattern { struct ONNXSoftmaxOpLowering : public ConversionPattern {
ONNXSoftmaxOpLowering(MLIRContext *ctx) ONNXSoftmaxOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,4 @@
//===----- conv.inc - Lowering Convolution Op -----------------------------===// //===----- conv.cpp - Lowering Convolution Op -----------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -8,13 +8,16 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXConvNoBiasOpLowering : public ConversionPattern { struct ONNXConvNoBiasOpLowering : public ConversionPattern {
ONNXConvNoBiasOpLowering(MLIRContext *ctx) ONNXConvNoBiasOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
@ -25,12 +28,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else else
alloc = insertAllocAndDealloc(
    memRefType, loc, rewriter, insertDealloc, {operands[0]});
auto resultShape = memRefType.getShape(); auto resultShape = memRefType.getShape();
auto inputShape = operands[0].getType().cast<MemRefType>().getShape(); auto &inputOperand = operands[0];
auto kernelShape = operands[1].getType().cast<MemRefType>().getShape(); auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
auto &kernelOperand = operands[1];
auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();
// R = ConvNoBias(D, K) // R = ConvNoBias(D, K)
// //
@ -91,123 +96,82 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
loc, FloatAttr::get(memRefType.getElementType(), 0)); loc, FloatAttr::get(memRefType.getElementType(), 0));
Value subchannels; Value subchannels;
if (kernelShape[1] < 0) { if (kernelShape[1] < 0) {
subchannels = subchannels = rewriter.create<DimOp>(loc, kernelOperand, 1).getResult();
rewriter.create<DimOp>(loc, operands[1], 1).getResult();
} else { } else {
subchannels = rewriter.create<ConstantIndexOp>( subchannels = rewriter.create<ConstantIndexOp>(loc, kernelShape[1]);
loc, kernelShape[1]);
} }
// 1. Define outer loops and emit empty optimization block: // 1. Define outer loops and emit empty optimization block:
int64_t nOuterLoops = (group > 1) ? 3 : 2; int64_t nOuterLoops = (group > 1) ? 3 : 2;
std::vector<Value> outerLoops; BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops);
std::vector<Value> optimizedOuterLoops; outerLoops.createDefineAndOptimizeOp();
Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
optimizedOuterLoops, nOuterLoops);
// Prepare iteration arguments over outer loop nest.
KrnlIterateOperandPack pack(
rewriter, outerLoops, optimizedOuterLoops);
// for n = 0 .. N: // for n = 0 .. N:
pack.pushConstantBound(0); int nIndex = outerLoops.pushBounds(0, inputOperand, 0);
if (inputShape[0] < 0)
pack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], 0).getResult());
else
pack.pushConstantBound(inputShape[0]);
// for g = 0 .. N: // for g = 0 .. N:
if (group > 1) { int gIndex = -1;
pack.pushConstantBound(0); if (group > 1)
pack.pushConstantBound(group); gIndex = outerLoops.pushBounds(0, group);
}
// for m = 0 .. kernelsPerGroup: // for m = 0 .. kernelsPerGroup:
pack.pushConstantBound(0); int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
pack.pushConstantBound(kernelsPerGroup); // Outer loop iteration
// Outer loop iteration. outerLoops.createIterateOp();
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack); rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
Block &outerIterationBlock = iterateOp.bodyRegion().front();
// Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
rewriter.setInsertionPointToStart(&outerIterationBlock);
{ {
// 2. Emit the body of the outer loop nest. // 2. Emit the body of the outer loop nest.
// 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m; // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
// If group is not set then the value of the kernel ID is // If group is not set then the value of the kernel ID is
// identical to that of the loop over kernels. // identical to that of the loop over kernels.
Value kernel = outerIterationBlock.getArguments()[1]; Value kernel = outerLoops.getInductionVar(mIndex);
if (group > 1) { if (group > 1) {
// Middle loop is over groups and third loop is over the // Middle loop is over groups and third loop is over the
// kernel identifiers in the current group. // kernel identifiers in the current group.
auto kernelsOffset = rewriter.create<MulIOp>(loc, auto kernelsOffset = rewriter.create<MulIOp>(
outerIterationBlock.getArguments()[1], loc, outerLoops.getInductionVar(gIndex), kernelsPerGroupValue);
kernelsPerGroupValue); kernel = rewriter.create<AddIOp>(
kernel = rewriter.create<AddIOp>(loc, kernelsOffset, loc, kernelsOffset, outerLoops.getInductionVar(mIndex));
outerIterationBlock.getArguments()[2]);
} }
// 2.2 Define spatial loops // 2.2 Define spatial loops
int64_t nSpatialLoops = resultShape.size() - 2; int64_t nSpatialLoops = resultShape.size() - 2;
std::vector<Value> spatialLoops; BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops);
std::vector<Value> optimizedSpatialLoops; spatialLoops.createDefineAndOptimizeOp();
Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
optimizedSpatialLoops, nSpatialLoops);
// 2.3 Prepare iteration arguments for spatial loop nest.
KrnlIterateOperandPack spatialPack(
rewriter, spatialLoops, optimizedSpatialLoops);
for (int i = 2; i < resultShape.size(); ++i) for (int i = 2; i < resultShape.size(); ++i)
addDimensionToPack(rewriter, loc, spatialPack, alloc, i); spatialLoops.pushBounds(0, alloc, i);
// 2.4 Emit loop nest over output spatial dimensions. // 2.4 Emit loop nest over output spatial dimensions.
// for rX = 0 .. RX // for rX = 0 .. RX
auto spatialIterateOp = spatialLoops.createIterateOp();
rewriter.create<KrnlIterateOp>(loc, spatialPack); rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock());
Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
// 2.5 Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
rewriter.setInsertionPointToStart(&spatialIterationBlock);
{ {
// 3. Emit the body of the spatial loop nest. // 3. Emit the body of the spatial loop nest.
// 3.1 Emit: R[n][kernel][r1][r2] = 0; // 3.1 Emit: R[n][kernel][r1][r2] = 0;
SmallVector<Value, 4> resultIndices; SmallVector<Value, 4> resultIndices;
// n // n
resultIndices.emplace_back(outerIterationBlock.getArguments()[0]); resultIndices.emplace_back(outerLoops.getInductionVar(nIndex));
// kernel // kernel
resultIndices.emplace_back(kernel); resultIndices.emplace_back(kernel);
// rX // rX
for (auto arg : spatialIterationBlock.getArguments()) for (auto arg : spatialLoops.getIterateBlock()->getArguments())
resultIndices.emplace_back(arg); resultIndices.emplace_back(arg);
// Store initializer value into output location. // Store initializer value into output location.
rewriter.create<StoreOp>(loc, zero, alloc, resultIndices); rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);
// 3.2 Define inner loops. // 3.2 Define inner loops.
int64_t nInnerLoops = 1 + (kernelShape.size() - 2); int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
std::vector<Value> innerLoops; BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops);
std::vector<Value> optimizedInnerLoops; innerLoops.createDefineAndOptimizeOp();
Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
optimizedInnerLoops, nInnerLoops);
// 3.3 Prepare iteration arguments for inner loop nest.
KrnlIterateOperandPack innerPack(
rewriter, innerLoops, optimizedInnerLoops);
// for c = 0 .. C/group // for c = 0 .. C/group
innerPack.pushConstantBound(0); int cIndex = innerLoops.pushBounds(0, kernelShape[1]);
innerPack.pushConstantBound(kernelShape[1]);
// for Kx = 0 .. KX // for Kx = 0 .. KX
for (int i = 2; i < kernelShape.size(); ++i) for (int i = 2; i < kernelShape.size(); ++i)
addDimensionToPack(rewriter, loc, innerPack, operands[1], i); innerLoops.pushBounds(0, kernelOperand, i);
// 3.4 Emit inner loop nest. // 3.4 Emit inner loop nest.
auto innerIterateOp = innerLoops.createIterateOp();
rewriter.create<KrnlIterateOp>(loc, innerPack); rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
// 3.5 Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optInnerLoopBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
rewriter.setInsertionPointToStart(&innerIterationBlock);
{ {
// 4. Emit inner loop body // 4. Emit inner loop body
// R[n][kernel][r1][r2] = // R[n][kernel][r1][r2] =
@ -217,13 +181,13 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// 4.1 Prepare indices for accessing the data tensor. // 4.1 Prepare indices for accessing the data tensor.
SmallVector<Value, 4> dataIndices; SmallVector<Value, 4> dataIndices;
// n // n
dataIndices.emplace_back(outerIterationBlock.getArguments()[0]); dataIndices.emplace_back(outerLoops.getInductionVar(nIndex));
// g * (C / group) + c // g * (C / group) + c
Value channelDepth = innerIterationBlock.getArguments()[0]; Value channelDepth = innerLoops.getInductionVar(cIndex);
if (group > 1) if (group > 1)
channelDepth = rewriter.create<AddIOp>(loc, channelDepth, channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
rewriter.create<MulIOp>(loc, subchannels, rewriter.create<MulIOp>(
outerIterationBlock.getArguments()[1])); loc, subchannels, outerLoops.getInductionVar(gIndex)));
dataIndices.emplace_back(channelDepth); dataIndices.emplace_back(channelDepth);
// sX * rX + kX // sX * rX + kX
auto stridesAttribute = convOp.stridesAttr(); auto stridesAttribute = convOp.stridesAttr();
@ -233,15 +197,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
for (auto stride : stridesAttribute.getValue()) for (auto stride : stridesAttribute.getValue())
strides.emplace_back(stride.cast<IntegerAttr>().getInt()); strides.emplace_back(stride.cast<IntegerAttr>().getInt());
for (int i = 0; i < kernelShape.size() - 2; ++i) { for (int i = 0; i < kernelShape.size() - 2; ++i) {
Value spatialIndex = spatialIterationBlock.getArguments()[i]; Value spatialIndex = spatialLoops.getInductionVar(i);
// If strides are present then emit the correct access index. // If strides are present then emit the correct access index.
if (stridesAttribute && strides[i] > 1) if (stridesAttribute && strides[i] > 1)
spatialIndex = rewriter.create<MulIOp>(loc, spatialIndex = rewriter.create<MulIOp>(loc,
rewriter.create<ConstantIndexOp>(loc, strides[i]), rewriter.create<ConstantIndexOp>(loc, strides[i]),
spatialIterationBlock.getArguments()[i]); spatialLoops.getInductionVar(i));
dataIndices.emplace_back( dataIndices.emplace_back(rewriter.create<AddIOp>(
rewriter.create<AddIOp>(loc, spatialIndex, loc, spatialIndex, innerLoops.getInductionVar(i + 1)));
innerIterationBlock.getArguments()[i+1]));
} }
// 4.2 Prepare indices for accessing the kernel tensor. // 4.2 Prepare indices for accessing the kernel tensor.
@ -249,17 +212,16 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// kernel // kernel
kernelIndices.emplace_back(kernel); kernelIndices.emplace_back(kernel);
// c // c
kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]); kernelIndices.emplace_back(innerLoops.getInductionVar(cIndex));
// kX // kX
for (int i = 0; i < kernelShape.size() - 2; ++i) for (int i = 0; i < kernelShape.size() - 2; ++i)
kernelIndices.emplace_back( kernelIndices.emplace_back(innerLoops.getInductionVar(i + 1));
innerIterationBlock.getArguments()[i+1]);
// 4.3 Compute convolution. // 4.3 Compute convolution.
auto loadData = auto loadData =
rewriter.create<LoadOp>(loc, operands[0], dataIndices); rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
auto loadKernel = auto loadKernel =
rewriter.create<LoadOp>(loc, operands[1], kernelIndices); rewriter.create<LoadOp>(loc, kernelOperand, kernelIndices);
auto loadPartialSum = auto loadPartialSum =
rewriter.create<LoadOp>(loc, alloc, resultIndices); rewriter.create<LoadOp>(loc, alloc, resultIndices);
Value result = rewriter.create<AddFOp>(loc, loadPartialSum, Value result = rewriter.create<AddFOp>(loc, loadPartialSum,

View File

@ -1,4 +1,4 @@
//===----- normalization.inc - Lowering Normalization Ops -----------------===// //===----- normalization.cpp - Lowering Normalization Ops -----------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -8,6 +8,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern { struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx) ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
: ConversionPattern( : ConversionPattern(

View File

@ -0,0 +1,324 @@
//====-- onnx_to_krnl_common.cpp - ONNX dialects to Krnl lowering ---------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the KRNL dialect.
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
/// Check if all dimensions are known at compile time.
bool hasAllConstantDimensions(MemRefType type) {
auto memRefShape = type.getShape();
for (int i = 0; i < memRefShape.size(); ++i)
if (memRefShape[i] < 0)
return false;
return true;
}
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
MemRefType convertToMemRefType(Type type) {
MemRefType memRefType;
auto tensorType = type.dyn_cast<TensorType>();
if (tensorType) {
assert(tensorType.hasRank() && "expected only ranked shapes");
memRefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
} else {
memRefType = type.dyn_cast<MemRefType>();
}
return memRefType;
}
/// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter,
bool insertDealloc,
ArrayRef<Value> operands) {
// Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc;
if (!operands.empty()) {
auto memRefShape = type.getShape();
auto rank = memRefShape.size();
std::map<int, Value> fromOperands;
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
int memRefDimIdx = rank - 1 - reversedIdx;
if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
Value maxDim = nullptr;
for (int i = 0; i < operands.size(); i++) {
auto operandShape =
operands[i].getType().cast<MemRefType>().getShape();
int operandDimIdx = operandShape.size() - 1 - reversedIdx;
if (operandDimIdx < 0)
continue;
// In case of operations with broadcasting, the dimension of the
// alloc result is the maximum size along each dimension of the
// operands.
auto operandDim =
rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
if (maxDim) {
auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
operandDim, maxDim);
maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
maxDim);
} else {
maxDim = operandDim;
}
}
fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
}
}
SmallVector<Value, 4> allocOperands;
for (int i = 0; i < rank; ++i)
if (memRefShape[i] < 0)
allocOperands.push_back(fromOperands[i]);
alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
} else {
alloc = rewriter.create<AllocOp>(loc, type);
}
// Make sure to allocate at the beginning of the block if
// all dimensions are known.
auto *parentBlock = alloc.getOperation()->getBlock();
if (hasAllConstantDimensions(type))
alloc.getOperation()->moveBefore(&parentBlock->front());
if (insertDealloc) {
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
return alloc;
}
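(Illustration, not part of the diff: a lowering pattern with a dynamically shaped
result might invoke this helper roughly as sketched below; the operand list and
the call to checkInsertDealloc are assumptions chosen for the example.)
  // Allocate the result, sizing each unknown dimension from the operands that
  // may broadcast into it.
  MemRefType memRefType = convertToMemRefType(*op->result_type_begin());
  bool insertDealloc = checkInsertDealloc(op);
  Value alloc = insertAllocAndDealloc(
      memRefType, loc, rewriter, insertDealloc, {operands[0], operands[1]});
  // Each dynamic dimension of `alloc` is then computed as the maximum of the
  // corresponding operand dimensions via the cmpi/select sequence above.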
// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
bool checkInsertDealloc(Operation *currentOp) {
auto parentBlock = currentOp->getBlock();
bool insertDealloc = true;
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
assert(currentOp->getNumResults() < 2 &&
"No more than one result supported (for now).");
// If there is at least one result to investigate.
if (currentOp->getNumResults() > 0) {
auto result = currentOp->getResult(0);
for (const auto &operand : op.getOperands())
if (operand == result)
insertDealloc = false;
}
});
return insertDealloc;
}
// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
std::map<int64_t, int64_t> OutInDimMap;
int64_t rank = inputTy.getRank();
// Mark reduction axes.
std::vector<bool> isReductionAxis;
for (decltype(rank) i = 0; i < rank; ++i) {
if (std::find(axes.begin(), axes.end(), i) != axes.end())
isReductionAxis.push_back(true);
else
isReductionAxis.push_back(false);
}
for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
// If it is a reduction axis, there is no relationship among dimensions.
if (isReductionAxis[inIndex]) {
if (keepdims)
outIndex++;
} else {
OutInDimMap.insert(std::make_pair(outIndex, inIndex));
outIndex++;
}
}
return OutInDimMap;
}
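(A worked example, assumed for illustration: reducing a rank-3 input over
axes = {1}.)
  // keepdims == false: output has rank 2; mapping is {0 -> 0, 1 -> 2}.
  // keepdims == true : output has rank 3; mapping is {0 -> 0, 2 -> 2}, since
  //                    output dim 1 corresponds to the reduced axis and has no
  //                    input dimension associated with it.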
// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter,
Location loc, KrnlIterateOperandPack &pack,
Value operand, int index) {
auto shape = operand.getType().cast<MemRefType>().getShape();
if (shape[index] < 0) {
pack.pushConstantBound(0);
pack.pushOperandBound(
rewriter.create<DimOp>(loc, operand, index).getResult());
} else {
pack.pushConstantBound(0);
pack.pushConstantBound(shape[index]);
}
}
// Function that defines the KRNL dialect loops and their respective
// optimized version.
KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops, int64_t numLoops) {
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
loops.reserve(numLoops);
for (auto result : loopsOp.getResults())
loops.push_back(result);
// Define optimized version of the loops.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
optimizedLoops.reserve(numLoops);
for (auto result : optimizedLoopsOp.getResults())
optimizedLoops.push_back(result);
return optimizedLoopsOp;
}
// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops,
int64_t numLoops) {
KrnlOptimizeLoopsOp optimizedLoopsOp =
emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
return &optimizedLoopsOp.region().front();
}
// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand(
ConversionPatternRewriter &rewriter, Location loc, Value operand,
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
KrnlIterateOp &iterateOp) {
// Operand shape.
auto shape = operand.getType().cast<MemRefType>().getShape();
// Number of loops.
int64_t rank = shape.size();
// Define loops and optimized loops.
std::vector<Value> optimizedLoops;
optimizedLoopsOp =
emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Iterate over the loop nest.
for (int i = 0; i < rank; ++i)
addDimensionToPack(rewriter, loc, pack, operand, i);
iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
}
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
auto elementType = memRefType.getElementType();
unsigned sizeInBits;
if (elementType.isIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
} else {
auto vectorType = elementType.cast<VectorType>();
sizeInBits =
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
}
return llvm::divideCeil(sizeInBits, 8);
}
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
MemRefType memRefType, ArrayRef<Value> operands) {
auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size();
// For unknown dimensions, we need to get dimension values at runtime in
// order to do broadcasting.
std::map<int, std::map<int, Value>> DimInfo;
// For each result dimension, compute the number of sharing operands.
// Sharing operands are operands sharing the same index (counting from the
// rightmost to the leftmost) for a given dimension.
std::map<int, int> sharedDimCount;
for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
int dimIdx = rank - 1 - reversedIdx;
sharedDimCount[dimIdx] = 0;
for (int i = 0; i < operands.size(); ++i) {
auto shape = operands[i].getType().cast<MemRefType>().getShape();
if (reversedIdx <= shape.size() - 1)
sharedDimCount[dimIdx]++;
}
}
// An unknown dimension can have a value of 1 or N (N > 1).
// If its value is 1, it is a broadcasted dimension;
// otherwise, it is a non-broadcasted dimension.
// We only care about unknown dimensions whose number of sharing operands is
// more than one, since they are potentially broadcasted dimensions.
for (int i = 0; i < operands.size(); ++i) {
std::map<int, Value> broadcastedDims;
auto shape = operands[i].getType().cast<MemRefType>().getShape();
int size = shape.size();
for (int j = 0; j < shape.size(); ++j) {
if (shape[j] < 0 && sharedDimCount[rank - size + j] > 1) {
auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
auto one = rewriter.create<ConstantIndexOp>(loc, 1);
auto isBroadcasted =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
broadcastedDims.insert(std::make_pair(j, isBroadcasted));
}
}
DimInfo.insert(std::make_pair(i, broadcastedDims));
}
return DimInfo;
}
// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value> broadcastedDims) {
// `operand` must have a ranked type. This should have been checked by the
// shape inference pass.
auto operandShape = operand.getType().cast<MemRefType>().getShape();
auto rank = operandShape.size();
auto loopCount = loopIVs.size();
std::vector<Value> newLoopIVs;
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
auto dimIdx = rank - 1 - reversedIdx;
auto loopIdx = loopCount - 1 - reversedIdx;
if (operandShape[dimIdx] == 1) {
// Broadcasted dimension
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
newLoopIVs.insert(newLoopIVs.begin(), zero);
} else if ((operandShape[dimIdx] == -1) &&
(broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
// Unknown dimension, it can have a value of 1 or N (N > 1).
// If its value is 1, it is a broadcasted dimension;
// otherwise, it is a non-broadcasted dimension.
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
loopIVs[loopIdx]);
newLoopIVs.insert(newLoopIVs.begin(), idx);
} else {
// Non-broadcasted dimension
newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
}
}
return newLoopIVs;
}
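(Illustrative example, not in the source: a rank-3 result with loop IVs
(i0, i1, i2) and an operand of shape [1, -1, 5].)
  // dim 2 (size 5): not broadcast, use i2 directly.
  // dim 1 (size -1): unknown at compile time; emit select(isBroadcasted, 0, i1)
  //                  using the runtime flag produced by getBroadcastedDimInfo.
  // dim 0 (size 1): broadcast dimension, use the constant index 0.
  // Access IVs for the operand become (0, select(...), i2).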

View File

@ -0,0 +1,217 @@
//====-- onnx_to_krnl_common.hpp - ONNX dialects to Krnl lowering ---------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the KRNL dialect.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <map>
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "mlir/IR/PatternMatch.h"
#include "src/dialect/krnl/krnl_helper.hpp"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Common functions used when lowering the ONNX frontend dialect to KRNL.
//===----------------------------------------------------------------------===//
/// Check if all dimensions are known at compile time.
bool hasAllConstantDimensions(MemRefType type);
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
MemRefType convertToMemRefType(Type type);
/// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter,
bool insertDealloc,
ArrayRef<Value> operands = {});
// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
bool checkInsertDealloc(Operation *currentOp);
// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);
// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter,
Location loc, KrnlIterateOperandPack &pack,
Value operand, int index);
// Function that defines the KRNL dialect loops and their respective
// optimized version.
KrnlOptimizeLoopsOp
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops, int64_t numLoops);
// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops,
std::vector<Value> &optimizedLoops,
int64_t numLoops);
// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand(
ConversionPatternRewriter &rewriter, Location loc, Value operand,
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
KrnlIterateOp &iterateOp);
unsigned getMemRefEltSizeInBytes(MemRefType memRefType);
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
MemRefType memRefType, ArrayRef<Value> operands);
// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value> loopIVs, Value operand,
std::map<int, Value> broadcastedDims);
//===----------------------------------------------------------------------===//
// This is to get a scalar operation of a given type for a specific operation.
//===----------------------------------------------------------------------===//
template <typename Op>
struct ScalarOp {
using FOp = void;
using IOp = void;
};
template <typename FOp>
using ScalarFOp = typename ScalarOp<FOp>::FOp;
template <typename IOp>
using ScalarIOp = typename ScalarOp<IOp>::IOp;
// Get the identity element of an operation.
// Return NULL if the operation has no identity element.
template <typename DataType, typename Op>
DataType getIdentityValue() {
return NULL;
}
//===----------------------------------------------------------------------===//
// This is used in the innermost loop of a KrnlIterateOp to insert computation
// composed of one or more scalar ops.
// Use template specialization for each of different ONNX operations.
//===----------------------------------------------------------------------===//
template <typename Op>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
auto loc = op->getLoc();
Type element_type = operands.front().getType();
if (element_type.isa<IntegerType>()) {
return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
mlir::None);
} else if (element_type.isa<FloatType>()) {
return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
mlir::None);
} else {
emitError(loc, "unsupported element type");
return nullptr;
}
}
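(A hedged sketch of how the ScalarOp mapping is intended to be specialized; the
ONNXAddOp wiring below is an assumption for illustration, not taken from this
diff.)
  // Map ONNX Add onto the standard-dialect scalar ops.
  template <>
  struct ScalarOp<ONNXAddOp> {
    using FOp = AddFOp; // used for floating-point element types
    using IOp = AddIOp; // used for integer element types
  };
  // With this in place, mapToLowerScalarOp<ONNXAddOp>(...) emits an AddFOp or
  // an AddIOp depending on the element type of its operands.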
//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//
struct TensorTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
TensorTypeConverter() {
addConversion(convertType);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
if (auto type = convertToMemRefType(t)) {
results.push_back(type);
return success();
}
results.push_back(t);
return success();
}
/// Return true if the inputs and outputs of the given function type are
/// legal. [Taken from MLIR and adapted to only check the legality of the
/// inputs. Once unranked results can be handled gracefully this
/// override needs to be removed in favour of the original MLIR one.]
bool isSignatureLegal(FunctionType funcType) {
return llvm::all_of(funcType.getInputs(),
[this](Type type) { return isLegal(type); });
}
};
//===----------------------------------------------------------------------===//
// Functions to add lowering patterns for frontend operations.
//===----------------------------------------------------------------------===//
// `math` directory methods:
void populateLoweringONNXElementwiseOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx);
void populateLoweringONNXMatMulOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXReductionOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXSoftmaxOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
// `nn` directory methods:
void populateLoweringONNXConvOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXNormalizationOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
// `tensor` directory methods:
void populateLoweringONNXUnsqueezeOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXTransposeOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXReshapeOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXIdentityOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);

View File

@ -1,4 +1,4 @@
//===----- identity.inc - Lowering Identity Op ----------------------------===// //===----- identity.cpp - Lowering Identity Op ----------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -8,6 +8,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXIdentityOpLowering : public ConversionPattern { struct ONNXIdentityOpLowering : public ConversionPattern {
ONNXIdentityOpLowering(MLIRContext *ctx) ONNXIdentityOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,4 @@
//===----- reshape.inc - Lowering Reshape Op ------------------------------===// //===----- reshape.cpp - Lowering Reshape Op ------------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -8,6 +8,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXReshapeOpLowering : public ConversionPattern { struct ONNXReshapeOpLowering : public ConversionPattern {
ONNXReshapeOpLowering(MLIRContext *ctx) ONNXReshapeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,4 @@
//===----- transpose.inc - Lowering Transpose Op --------------------------===// //===----- transpose.cpp - Lowering Transpose Op --------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -8,6 +8,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXTransposeOpLowering : public ConversionPattern { struct ONNXTransposeOpLowering : public ConversionPattern {
ONNXTransposeOpLowering(MLIRContext *ctx) ONNXTransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,4 @@
//===----- unsqueeze.inc - Lowering Unsqueeze Op --------------------------===// //===----- unsqueeze.cpp - Lowering Unsqueeze Op --------------------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -8,6 +8,10 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
struct ONNXUnsqueezeOpLowering : public ConversionPattern { struct ONNXUnsqueezeOpLowering : public ConversionPattern {
ONNXUnsqueezeOpLowering(MLIRContext *ctx) ONNXUnsqueezeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}

View File

@ -1,4 +1,5 @@
#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExpr.h"
#include "src/dialect/krnl/krnl_ops.hpp" #include "src/dialect/krnl/krnl_ops.hpp"
@ -9,9 +10,8 @@ namespace onnf {
using namespace mlir; using namespace mlir;
ParseResult ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType, const Type &operandType, Value &operand) {
Value &operand) {
// If operand queue is empty, parse more operands and cache them. // If operand queue is empty, parse more operands and cache them.
if (_operandRefQueue.empty()) { if (_operandRefQueue.empty()) {
// Parse operand types: // Parse operand types:
@ -19,7 +19,7 @@ KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
_parser.parseOperandList(operand_refs); _parser.parseOperandList(operand_refs);
// Record operands: // Record operands:
for (auto& operand_ref : operand_refs) for (auto &operand_ref : operand_refs)
_operandRefQueue.emplace(operand_ref); _operandRefQueue.emplace(operand_ref);
} }
@ -48,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
return success(); return success();
} }
ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType, ParseResult KrnlDialectOperandParser::ParseOperand(
Value &operand) { const Type &operandType, Value &operand) {
if (ParseOptionalOperand(operandType, operand)) if (ParseOptionalOperand(operandType, operand))
return _parser.emitError( return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand."); _parser.getCurrentLocation(), "Expecting an operand.");
@ -65,8 +65,8 @@ ParseResult KrnlDialectOperandParser::ParseOperand(
return success(); return success();
} }
void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims, void printDimAndSymbolList(Operation::operand_iterator &begin, unsigned numDims,
unsigned numSymbols, OpAsmPrinter& p) { unsigned numSymbols, OpAsmPrinter &p) {
p << '('; p << '(';
p.printOperands(begin, begin + numDims); p.printOperands(begin, begin + numDims);
p << ')'; p << ')';
@ -81,8 +81,8 @@ void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
} }
void printBound(AffineMapAttr boundMap, void printBound(AffineMapAttr boundMap,
Operation::operand_iterator& boundOperandsBeg, const char* prefix, Operation::operand_iterator &boundOperandsBeg, const char *prefix,
OpAsmPrinter& p) { OpAsmPrinter &p) {
AffineMap map = boundMap.getValue(); AffineMap map = boundMap.getValue();
// Check if this bound should be printed using custom assembly form. // Check if this bound should be printed using custom assembly form.
@ -120,9 +120,10 @@ void printBound(AffineMapAttr boundMap,
printDimAndSymbolList( printDimAndSymbolList(
boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p); boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p);
} }
} // namespace onnf } // namespace onnf
namespace mlir { namespace mlir {
void KrnlIterateOperandPack::pushConstantBound(int64_t bound) { void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
if (boundMaps.size() % 2 == 0) if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]); _operands.emplace_back(inputLoops[boundMaps.size() / 2]);
@ -130,11 +131,143 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
boundMaps.emplace_back(AffineMapAttr::get(map)); boundMaps.emplace_back(AffineMapAttr::get(map));
} }
void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) { void KrnlIterateOperandPack::pushOperandBound(Value operand) {
if (boundMaps.size() % 2 == 0) if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]); _operands.emplace_back(inputLoops[boundMaps.size() / 2]);
AffineMap map = builder.getSymbolIdentityMap(); AffineMap map = builder.getSymbolIdentityMap();
boundMaps.emplace_back(AffineMapAttr::get(map)); boundMaps.emplace_back(AffineMapAttr::get(map));
_operands.emplace_back(operand); _operands.emplace_back(operand);
} }
} // namespace mlir
BuildKrnlLoop::BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, int loopNum)
: rewriter(rewriter), loc(loc), originalLoopNum(loopNum), pack(NULL),
pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
createdIterateOp(false) {
if (originalLoopNum <= 0)
emitError(loc, "Expected positive number of original loops.");
}
BuildKrnlLoop::BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand)
: BuildKrnlLoop(rewriter, loc,
memRefOperand.getType().cast<MemRefType>().getShape().size()) {}
BuildKrnlLoop::~BuildKrnlLoop() {
if (pack)
delete pack;
}
void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
// Insert define loop operation.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
originalLoops.reserve(originalLoopNum);
for (auto result : loopsOp.getResults())
originalLoops.push_back(result);
createdDefineOp = true;
// Insert optimize loop operation.
auto optimizedLoopsOp =
rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
optLoops.reserve(originalLoopNum);
// Emit empty optimizations if flag is set.
if (withEmptyOptimization) {
for (auto result : optimizedLoopsOp.getResults())
optLoops.push_back(result);
optBlock = &optimizedLoopsOp.region().front();
auto ip = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToEnd(optBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.restoreInsertionPoint(ip);
}
createdOptimizeOp = true;
// Prepare the data structure used to push bounds.
pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
}
int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
pack->pushConstantBound(lowerBound);
pack->pushConstantBound(upperBound);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
pack->pushConstantBound(lowerBound);
pack->pushOperandBound(upperBound);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
pack->pushConstantBound(lowerBound);
// Process upperBound as a dimension of the MemRef. Non-constant dimensions
// are supported.
auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
if (shape[upperBoundMemRefIndex] < 0) {
if (upperBoundMustBeConstant)
emitError(loc, "Bound expected to be constant.");
pack->pushOperandBound(
rewriter
.create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
.getResult());
} else
pack->pushConstantBound(shape[upperBoundMemRefIndex]);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) {
pack->pushOperandBound(lowerBound);
pack->pushOperandBound(upperBound);
return pushCount++;
}
void BuildKrnlLoop::createIterateOp() {
// Loop definition operation is mandatory.
if (!createdDefineOp)
emitError(loc, "Must create define op before iterate op.");
// Loop optimization operation is mandatory (for now).
if (!createdOptimizeOp)
emitError(loc, "Must create optimize op before iterate op.");
// Check if all bounds have been defined.
if (pushCount != originalLoopNum)
emitError(loc, "Must push bounds for all original loops.");
// Emit iteration operation.
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
iterBlock = &iterateOp.bodyRegion().front();
createdIterateOp = true;
}
void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization) {
// Rank of the MemRef operand. We will emit a loop for each dimension.
int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
if (originalLoopNum != loopNum)
emitError(loc, "Mismatch in loop numbers from constructor and define.");
// Emit the definition and the optimization operations for the loop nest.
createDefineAndOptimizeOp(withEmptyOptimization);
// Push a lower-upper bound pair for each dimension of the MemRef operand.
// The lower bound in this case is always zero.
for (int i = 0; i < originalLoopNum; ++i)
pushBounds(0, memRefOperand, i);
// Emit the iteration operation over the current loop nest.
createIterateOp();
}
BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
// Check if loop iteration variable is within bounds.
if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
emitError(loc, "Original loop index is out of bounds.");
return iterBlock->getArguments()[originalLoopIndex];
}
} // namespace mlir

View File

@ -8,39 +8,38 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnf { namespace onnf {
class KrnlDialectOperandParser { class KrnlDialectOperandParser {
public: public:
explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser) explicit KrnlDialectOperandParser(mlir::OpAsmParser &parser)
: _parser(parser), _builder(parser.getBuilder()){}; : _parser(parser), _builder(parser.getBuilder()){};
// Parse an optional operand. // Parse an optional operand.
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType, mlir::ParseResult ParseOptionalOperand(
mlir::Value &operand); const mlir::Type &operandType, mlir::Value &operand);
// Parse an optional operand and push it to an operand list. // Parse an optional operand and push it to an operand list.
mlir::ParseResult mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
ParseOptionalOperand(const mlir::Type &operandType, llvm::SmallVectorImpl<mlir::Value> &operandList);
llvm::SmallVectorImpl<mlir::Value> &operandList);
// Parse a required operand. // Parse a required operand.
mlir::ParseResult ParseOperand(const mlir::Type &operandType, mlir::ParseResult ParseOperand(
mlir::Value &operand); const mlir::Type &operandType, mlir::Value &operand);
// Parse a required operand and push it to an operand list. // Parse a required operand and push it to an operand list.
mlir::ParseResult mlir::ParseResult ParseOperand(const mlir::Type &operandType,
ParseOperand(const mlir::Type &operandType, llvm::SmallVectorImpl<mlir::Value> &operandList);
llvm::SmallVectorImpl<mlir::Value> &operandList);
// Do we have more operands to parse? // Do we have more operands to parse?
bool hasOperandLeft() { return !_operandRefQueue.empty(); } bool hasOperandLeft() { return !_operandRefQueue.empty(); }
private: private:
mlir::OpAsmParser& _parser; mlir::OpAsmParser &_parser;
mlir::Builder& _builder; mlir::Builder &_builder;
// A queue storing the parsed SSA id references. // A queue storing the parsed SSA id references.
std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue; std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue;
@ -50,24 +49,24 @@ class KrnlDialectOperandParser {
// https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197 // https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
// Main difference is that it advances the iterator `begin` as it consumes // Main difference is that it advances the iterator `begin` as it consumes
// dimension and symbol operands. // dimension and symbol operands.
void printDimAndSymbolList(mlir::Operation::operand_iterator& begin, void printDimAndSymbolList(mlir::Operation::operand_iterator &begin,
unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p); unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter &p);
// Adapted from: // Adapted from:
// https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272 // https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
// Main difference is that it advances the iterator `boundOperandsBeg` as it // Main difference is that it advances the iterator `boundOperandsBeg` as it
// prints bound. // prints bound.
void printBound(mlir::AffineMapAttr boundMap, void printBound(mlir::AffineMapAttr boundMap,
mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix, mlir::Operation::operand_iterator &boundOperandsBeg, const char *prefix,
mlir::OpAsmPrinter& p); mlir::OpAsmPrinter &p);
} // namespace onnf } // namespace onnf
namespace mlir { namespace mlir {
struct KrnlIterateOperandPack { struct KrnlIterateOperandPack {
KrnlIterateOperandPack(mlir::Builder &builder, KrnlIterateOperandPack(mlir::Builder &builder,
llvm::ArrayRef<mlir::Value> inputLoops, llvm::ArrayRef<mlir::Value> inputLoops,
llvm::ArrayRef<mlir::Value> optimizedLoops) llvm::ArrayRef<mlir::Value> optimizedLoops)
: builder(builder), inputLoops(inputLoops), : builder(builder), inputLoops(inputLoops),
optimizedLoops(optimizedLoops) { optimizedLoops(optimizedLoops) {
_operands.insert( _operands.insert(
@ -88,7 +87,7 @@ struct KrnlIterateOperandPack {
size_t getNumInputLoops() const { return inputLoops.size(); } size_t getNumInputLoops() const { return inputLoops.size(); }
private: private:
int _boundIdx = 0; int _boundIdx = 0;
llvm::SmallVector<mlir::Value, 8> _operands; llvm::SmallVector<mlir::Value, 8> _operands;
@ -97,7 +96,124 @@ struct KrnlIterateOperandPack {
llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops; llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;
mlir::Builder& builder; mlir::Builder &builder;
}; };
} // namespace mlir // Helper class to write kernel loops. It lets us build a single
// define/optimize/iterate operation combo. We can then insert optimizations in
// the body of the optimization operation, and operations in the body of the
// iterate operation.
//
// The sequence is as follows:
//
// 1) Create an object giving the rewriter, location, and number of loops in
//    the original (non-optimized) loop nest.
//
// 2) Create define & optimize ops (currently paired). Optimizations can then
// be added to the inner block of the optimize operation. Make sure to set
// the insertion point to that block for optimizations to go in the right
// place.
//
// 3) Push the bounds for each of the original loops. Bounds are pushed in
// pairs (lower & upper bounds). There are a few methods to do it depending
//    on the type of the bounds. When pushing bounds, the method returns a
//    number that represents the index associated with that iteration (induction
//    variable and bounds). That index can be used later to retrieve the
//    induction variable for use in computations and/or MemRef index
//    calculations.
//
// 4) Once all the bounds are pushed, create the iterate operation. Once this
//    is done, we can add operations within the iterate block by setting the
//    insertion point to it. Values of the induction variables can be retrieved
//    using the proper index (determined when pushing the bounds). A usage
//    sketch is given after the class declaration below.
class BuildKrnlLoop {
public:
// Create kernel loop builder for a loop nest of depth loopNum.
BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
// Create kernel loop builder for a loop nest of depth equal to the
// dimensionality of the operand. An operand of MemRef type is required.
BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
~BuildKrnlLoop();
// Create the define and optimize operations for loopNum original loops. If
// withEmptyOptimization is true, the optimization is simply the identity
// function (no optimizations).
void createDefineAndOptimizeOp(bool withEmptyOptimization = true);
// Push bounds (lower and upper) for each of the loops (order matters).
// The function returns the order number associated with the loop iteration.
// This index is used by the getInductionVar call. Non-constant operands
// must be of MemRef type.
int pushBounds(int64_t lowerBound, int64_t upperBound);
int pushBounds(int64_t lowerBound, Value upperBound);
int pushBounds(Value lowerBound, Value upperBound);
int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);
// Create the KrnlIterateOp associated with this loop nest. The loop
// iteration can only be created once the definition and optimization
// operations for this loop nest have already been emitted.
void createIterateOp();
// Create the loop nest definition, optimization and iteration operations
// for a given operand of MemRef type. The loop nest has a depth equal to the
// rank of the MemRef operand. The lower bound of each loop is zero. The
// upper bound of each loop is given by the corresponding dimension of the
// MemRef operand.
void createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization = true);
// Get the (original loop) induction variable associated with the given
// index. Use the index returned when pushing the bounds.
BlockArgument &getInductionVar(int originalLoopIndex);
// Get a reference to the code region of the optimization operation.
// This allows us to set the insertion point to the inner block of the
// loop nest optimization operation.
Block *getOptimizationBlock() { return optBlock; }
// Get a reference to the code region of the iteration operation.
// This allows us to set the insertion point to the inner block of the
// loop nest iteration operation.
Block *getIterateBlock() { return iterBlock; }
// Get original loop nest.
std::vector<Value> &getOriginalLoops() { return originalLoops; }
// Get optimized loop nest.
std::vector<Value> &getOptimizedLoops() { return optLoops; }
private:
// Required for emitting operations.
ConversionPatternRewriter &rewriter;
Location loc;
int originalLoopNum;
// List of original, un-optimized loops.
std::vector<Value> originalLoops;
// List of optimized loops.
std::vector<Value> optLoops;
// List of lower-upper bound pairs needed by the KrnlIterateOp.
KrnlIterateOperandPack *pack;
// Number of lower-upper bound pairs pushed.
int pushCount;
// Flags that keep track of emitted operations.
bool createdDefineOp;
bool createdOptimizeOp;
bool createdIterateOp;
// Saved insertion point in the code region of the KrnlOptimizeLoopsOp.
Block *optBlock;
// Saved insertion point in the code region of the KrnlIterateOp.
Block *iterBlock;
};
} // namespace mlir
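(A minimal usage sketch of the sequence documented above, assuming a 2-D MemRef
value `operand` inside a conversion pattern; all names are illustrative only.)
  BuildKrnlLoop loops(rewriter, loc, 2);
  loops.createDefineAndOptimizeOp();            // steps 1) and 2)
  int iIndex = loops.pushBounds(0, operand, 0); // step 3): 0 <= i < dim(operand, 0)
  int jIndex = loops.pushBounds(0, operand, 1); //          0 <= j < dim(operand, 1)
  loops.createIterateOp();                      // step 4)
  rewriter.setInsertionPointToStart(loops.getIterateBlock());
  Value i = loops.getInductionVar(iIndex);
  Value j = loops.getInductionVar(jIndex);
  // ... emit the loop body here ...
  // createDefineOptimizeAndIterateOp(operand) collapses the same steps when
  // every loop simply spans one dimension of `operand`.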

View File

@ -90,25 +90,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
// or outputs. This decision affects only ONNX operations with optional // or outputs. This decision affects only ONNX operations with optional
// arguments not ONNX operations with variadic operands. // arguments not ONNX operations with variadic operands.
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation without bias.";
let description = [{
The "onnx.Gemm" generic matrix multiplication without bias.
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
DefaultValuedAttr<F32Attr, "1.0">:$alpha,
DefaultValuedAttr<F32Attr, "1.0">:$beta,
DefaultValuedAttr<I64Attr, "0">:$transA,
DefaultValuedAttr<I64Attr, "0">:$transB);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias", def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1; let hasCanonicalizer = 1;

View File

@ -24,12 +24,29 @@
using namespace mlir; using namespace mlir;
using namespace mlir::OpTrait::util; using namespace mlir::OpTrait::util;
//===----------------------------------------------------------------------===//
// ONNX Helper functions
//===----------------------------------------------------------------------===//
static size_t ArrayAttrSize(ArrayAttr a) { return a.size(); }
static size_t ArrayAttrSize(Optional<ArrayAttr> a) {
return a.getValue().size();
}
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
}
static int64_t ArrayAttrIntVal(Optional<ArrayAttr> a, int i) {
return (a.getValue().getValue()[i]).cast<IntegerAttr>().getInt();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Get reduction type // Get reduction type
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
RankedTensorType getReductionOutputType(RankedTensorType operandTy, RankedTensorType getReductionOutputType(
Optional<ArrayAttr> axesAttrs, RankedTensorType operandTy, Optional<ArrayAttr> axesAttrs, APInt keepdims) {
APInt keepdims) {
int64_t rank = operandTy.getRank(); int64_t rank = operandTy.getRank();
SmallVector<int64_t, 4> axes; SmallVector<int64_t, 4> axes;
@ -87,19 +104,18 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
} }
void ONNXEntryPointOp::build(mlir::Builder *builder, void ONNXEntryPointOp::build(mlir::Builder *builder,
mlir::OperationState &state, mlir::FuncOp function, mlir::OperationState &state, mlir::FuncOp function, int numInputs,
int numInputs, int numOutputs) { int numOutputs) {
state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(), state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
builder->getSymbolRefAttr(function)); builder->getSymbolRefAttr(function));
state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(), state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
builder->getI32IntegerAttr(numInputs)); builder->getI32IntegerAttr(numInputs));
state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(), state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
builder->getI32IntegerAttr(numOutputs)); builder->getI32IntegerAttr(numOutputs));
} }
ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location, ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
mlir::FuncOp &func, int numInputs, mlir::FuncOp &func, int numInputs, int numOutputs) {
int numOutputs) {
mlir::OperationState state(location, "onnx.EntryPoint"); mlir::OperationState state(location, "onnx.EntryPoint");
Builder builder(location->getContext()); Builder builder(location->getContext());
mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs); mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
@ -120,25 +136,19 @@ void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
// Tanh // Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the /// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXTanhOp::inferShapes() { void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sinh // Sinh
/// Infer the output shape of the ONNXSinhOp. This method is required by the /// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXSinhOp::inferShapes() { void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Cosh // Cosh
/// Infer the output shape of the ONNXCoshOp. This method is required by the /// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXCoshOp::inferShapes() { void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Cos // Cos
@ -178,9 +188,7 @@ void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
// Relu // Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the /// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXReluOp::inferShapes() { void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LeakyRelu // LeakyRelu
@ -194,9 +202,7 @@ void ONNXLeakyReluOp::inferShapes() {
// Selu // Selu
/// Infer the output shape of the ONNXSeluOp. This method is required by /// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSeluOp::inferShapes() { void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Reciprocal // Reciprocal
@ -234,17 +240,13 @@ void ONNXSoftsignOp::inferShapes() {
// Sqrt // Sqrt
/// Infer the output shape of the ONNXSqrtOp. This method is required by /// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSqrtOp::inferShapes() { void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sign // Sign
/// Infer the output shape of the ONNXSignOp. This method is required by /// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSignOp::inferShapes() { void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Add // Add
@ -404,12 +406,12 @@ void ONNXIdentityOp::inferShapes() {
void ONNXMatMulOp::inferShapes() { void ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!A().getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !B().getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = B().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
auto lhsShape = lhsTy.getShape(); auto lhsShape = lhsTy.getShape();
@ -417,15 +419,14 @@ void ONNXMatMulOp::inferShapes() {
if (lhsShape.size() < 1 && rhsShape.size() < 1) { if (lhsShape.size() < 1 && rhsShape.size() < 1) {
// Multiplication by scalars is not allowed. // Multiplication by scalars is not allowed.
emitError("Multiplication by scalar arguments not allowed."); emitError("Multiplication by scalar arguments not allowed");
} else if (lhsShape.size() == 1 && rhsShape.size() == 1) { } else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
// Special case when both arrays are 1-dimensional and according to // Special case when both arrays are 1-dimensional and according to
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
// need to be removed after the multiplication but cannot be removed if all // need to be removed after the multiplication but cannot be removed if all
// sizes are 1. // sizes are 1.
if (lhsShape[0] != -1 && rhsShape[0] != -1 && if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
lhsShape[0] != rhsShape[0]) emitError("Attempt to multiply incompatible matrices");
emitError("Attempt to multiply incompatible matrices.");
dims.emplace_back(1); dims.emplace_back(1);
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) { } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
// If the first argument is 1-D, it is promoted to a matrix by prepending a // If the first argument is 1-D, it is promoted to a matrix by prepending a
@ -440,7 +441,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size(); unsigned rhsRank = rhsShape.size();
if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 && if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[0] != rhsShape[rhsRank - 2]) lhsShape[0] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]); dims.emplace_back(rhsShape[i]);
@ -458,7 +459,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned lhsRank = lhsShape.size(); unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 && if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0]) lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i) for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
dims.emplace_back(lhsShape[i]); dims.emplace_back(lhsShape[i]);
@ -472,7 +473,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned lhsRank = lhsShape.size(); unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 && if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0]) lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i) for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
dims.emplace_back(lhsShape[i]); dims.emplace_back(lhsShape[i]);
@ -486,7 +487,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size(); unsigned rhsRank = rhsShape.size();
if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 && if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[1] != rhsShape[rhsRank - 2]) lhsShape[1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]); dims.emplace_back(rhsShape[i]);
@ -502,7 +503,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size(); unsigned rhsRank = rhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 && if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2]) lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
// Check and perform broadcasting for the shapes. // Check and perform broadcasting for the shapes.
SmallVector<int64_t, 2> lhsBcastShape; SmallVector<int64_t, 2> lhsBcastShape;
@ -512,7 +513,7 @@ void ONNXMatMulOp::inferShapes() {
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
rhsBcastShape.emplace_back(rhsShape[i]); rhsBcastShape.emplace_back(rhsShape[i]);
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
emitError("Broadcasted dimensions are incompatible."); emitError("Broadcasted dimensions are incompatible");
dims.emplace_back(lhsShape[lhsRank - 2]); dims.emplace_back(lhsShape[lhsRank - 2]);
dims.emplace_back(rhsShape[rhsRank - 1]); dims.emplace_back(rhsShape[rhsRank - 1]);
@ -527,7 +528,7 @@ void ONNXMatMulOp::inferShapes() {
// Check legality of matrix multiplication. // Check legality of matrix multiplication.
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim) if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
if (rhsShape.size() > 1) if (rhsShape.size() > 1)
dims.emplace_back(rhsShape[1]); dims.emplace_back(rhsShape[1]);
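
// For illustration only: a minimal, self-contained sketch of the numpy-style
// broadcasting rule that the batch-dimension handling above relies on (via
// getBroadcastedShape). The helper name and the treatment of dynamic (-1)
// dimensions are assumptions for this sketch, not the ONNF implementation.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

static bool broadcastShapesSketch(const std::vector<int64_t> &lhs,
                                  const std::vector<int64_t> &rhs,
                                  std::vector<int64_t> &out) {
  std::size_t rank = std::max(lhs.size(), rhs.size());
  out.assign(rank, 1);
  for (std::size_t i = 0; i < rank; ++i) {
    // Align both shapes on the right; missing leading dimensions act like 1.
    int64_t l = (i < rank - lhs.size()) ? 1 : lhs[i - (rank - lhs.size())];
    int64_t r = (i < rank - rhs.size()) ? 1 : rhs[i - (rank - rhs.size())];
    if (l == r)
      out[i] = l;
    else if (l == 1)
      out[i] = r;
    else if (r == 1)
      out[i] = l;
    else if (l == -1 || r == -1)
      out[i] = (l == -1) ? r : l; // dynamic vs. static: assume they agree
    else
      return false; // incompatible static dimensions
  }
  return true;
}
// Example: batch shapes {2, 3, 1} and {3, 4} broadcast to {2, 3, 4}, while
// {2, 3} and {4, 3} are rejected.
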
@ -541,14 +542,14 @@ void ONNXMatMulOp::inferShapes() {
// Gemm // Gemm
void ONNXGemmOp::inferShapes() { void ONNXGemmOp::inferShapes() {
bool hasBias = !C().getType().isa<NoneType>();
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!A().getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() || !B().getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>()) (hasBias && !C().getType().isa<RankedTensorType>()))
return; return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = B().getType().cast<RankedTensorType>();
auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B; int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1]; M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
@ -557,44 +558,21 @@ void ONNXGemmOp::inferShapes() {
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1]; K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) { if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
emitError("Tensor shapes mismatched."); emitError("Tensor shapes mismatched");
} }
// Check whether bias is unidirectional broadcasting or not. if (hasBias) {
auto shape = biasTy.getShape(); // Check whether bias is unidirectional broadcasting or not.
int rank = shape.size(); auto biasTy = C().getType().cast<RankedTensorType>();
if ((rank > 2) || auto shape = biasTy.getShape();
(rank >= 1 && shape[rank - 1] != -1 && N != -1 && N != shape[rank - 1] && int rank = shape.size();
shape[rank - 1] != 1) || if ((rank > 2) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] && (rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
shape[rank - 2] != 1)) { N != shape[rank - 1] && shape[rank - 1] != 1) ||
emitError("Bias shape mismatched."); (rank == 2 && shape[rank - 2] != -1 && M != -1 &&
} M != shape[rank - 2] && shape[rank - 2] != 1)) {
emitError("Bias shape mismatched");
SmallVector<int64_t, 2> dims; }
dims.emplace_back(M);
dims.emplace_back(N);
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
// GemmNoBias
void ONNXGemmNoBiasOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
emitError("Tensor shapes mismatched.");
} }
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
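
// For illustration only: the bias check above encodes ONNX Gemm's
// "unidirectional broadcast" requirement, i.e. C must broadcast to the
// (M x N) result. A standalone sketch of that predicate (assumed names,
// not the ONNF code):
#include <cstdint>
#include <vector>

static bool biasBroadcastsToMxN(const std::vector<int64_t> &biasShape,
                                int64_t M, int64_t N) {
  int rank = static_cast<int>(biasShape.size());
  if (rank > 2)
    return false;
  if (rank >= 1) {
    int64_t last = biasShape[rank - 1];
    if (last != -1 && N != -1 && last != N && last != 1)
      return false; // last dimension must be N, 1, or dynamic
  }
  if (rank == 2) {
    int64_t first = biasShape[0];
    if (first != -1 && M != -1 && first != M && first != 1)
      return false; // leading dimension must be M, 1, or dynamic
  }
  return true; // a rank-0 (scalar) bias always broadcasts
}
// Example: for M = N = 128 this accepts {128}, {1}, {1, 128}, {128, 1} and
// {128, 128}, and rejects {64} or {2, 128, 128}.
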
@ -606,50 +584,50 @@ void ONNXGemmNoBiasOp::inferShapes() {
/// BatchNormalizationTestMode /// BatchNormalizationTestMode
void ONNXBatchNormalizationTestModeOp::inferShapes() { void ONNXBatchNormalizationTestModeOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!X().getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() || !scale().getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>() || !B().getType().isa<RankedTensorType>() ||
!getOperand(3).getType().isa<RankedTensorType>() || !mean().getType().isa<RankedTensorType>() ||
!getOperand(4).getType().isa<RankedTensorType>()) !var().getType().isa<RankedTensorType>())
return; return;
auto input = getOperand(0).getType().cast<RankedTensorType>(); auto inputTensorTy = X().getType().cast<RankedTensorType>();
auto scale = getOperand(1).getType().cast<RankedTensorType>(); auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
auto bias = getOperand(2).getType().cast<RankedTensorType>(); auto biasTensorTy = B().getType().cast<RankedTensorType>();
auto mean = getOperand(3).getType().cast<RankedTensorType>(); auto meanTensorTy = mean().getType().cast<RankedTensorType>();
auto variance = getOperand(4).getType().cast<RankedTensorType>(); auto varianceTensorTy = var().getType().cast<RankedTensorType>();
// Check whether the shapes of scale, bias, mean and variance are valid. // Check whether the shapes of scale, bias, mean and variance are valid.
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N. // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1. // In case of N, C is assumed to be 1.
// Shapes of scale, bias, mean and variance must be C. // Shapes of scale, bias, mean and variance must be C.
int64_t c = -1; int64_t c = -1;
if (input.getShape().size() == 1) { if (inputTensorTy.getShape().size() == 1) {
c = 1; c = 1;
} else if (input.getShape().size() > 2) { } else if (inputTensorTy.getShape().size() > 2) {
c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1; c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
} else { } else {
emitError("Wrong rank for the input."); emitError("Wrong rank for the input");
} }
if (c != -1) { if (c != -1) {
auto s = scale.getShape(); auto s = scaleTensorTy.getShape();
auto b = bias.getShape(); auto b = biasTensorTy.getShape();
auto m = mean.getShape(); auto m = meanTensorTy.getShape();
auto v = variance.getShape(); auto v = varianceTensorTy.getShape();
if ((s.size() != 1) || (s[0] != -1 && s[0] != c)) if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
emitError("Wrong rank for the scale."); emitError("Wrong rank for the scale");
if ((b.size() != 1) || (b[0] != -1 && b[0] != c)) if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
emitError("Wrong rank for the bias."); emitError("Wrong rank for the bias");
if ((m.size() != 1) || (m[0] != -1 && m[0] != c)) if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
emitError("Wrong rank for the mean."); emitError("Wrong rank for the mean");
if ((v.size() != 1) || (v[0] != -1 && v[0] != c)) if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
emitError("Wrong rank for the variance."); emitError("Wrong rank for the variance");
} }
// The output tensor of the same shape as the input. // The output tensor of the same shape as the input.
getResult().setType(getOperand(0).getType()); getResult().setType(X().getType());
} }
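
// For illustration only: the rank/size rule enforced above. For an input of
// shape N x C x D1 x ... x Dn, each of scale, bias, mean and variance must be
// a 1-D tensor of size C; dynamic sizes (-1) are accepted. Sketch with
// assumed names, not the ONNF code:
#include <cstdint>
#include <vector>

static bool isValidBatchNormParam(const std::vector<int64_t> &paramShape,
                                  int64_t c) {
  return paramShape.size() == 1 && (paramShape[0] == -1 || paramShape[0] == c);
}
// Example: for X of shape {1, 64, 56, 56} (so c == 64), a scale of shape {64}
// passes, while {1, 64} or {32} would be reported as wrong rank/size.
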
// TODO: // TODO:
@ -662,21 +640,21 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
void ONNXReshapeOp::inferShapes() { void ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified. // Cannot infer shape if no shape tensor is specified.
if (!getOperand(1).getType().isa<RankedTensorType>()) if (!shape().getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked");
auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>(); auto inputTensorTy = data().getType().cast<RankedTensorType>();
auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>(); auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
// Only rank 1 shape tensors are supported. // Only rank 1 shape tensors are supported.
if (shapeTensorTy.getShape().size() != 1) if (shapeTensorTy.getShape().size() != 1)
emitError("Shape tensor must have rank one."); emitError("Shape tensor must have rank one");
int64_t outputRank = shapeTensorTy.getShape()[0]; int64_t outputRank = shapeTensorTy.getShape()[0];
// Shape tensor must have constant shape. // Shape tensor must have constant shape.
if (outputRank < 0) if (outputRank < 0)
emitError("Shape tensor must have constant shape."); emitError("Shape tensor must have constant shape");
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
for (int i = 0; i < outputRank; ++i) for (int i = 0; i < outputRank; ++i)
@ -692,12 +670,12 @@ void ONNXReshapeOp::inferShapes() {
void ONNXTransposeOp::inferShapes() { void ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand().getType().isa<RankedTensorType>()) if (!data().getType().isa<RankedTensorType>())
return; return;
// Naive transposition which handles the default case of // Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose). // reversing the shape of the tensor (similar to numpy.transpose).
auto arrayTy = getOperand().getType().cast<RankedTensorType>(); auto arrayTy = data().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
auto permutation = ONNXTransposeOp::permAttr(); auto permutation = ONNXTransposeOp::permAttr();
if (permutation) { if (permutation) {
@ -713,14 +691,13 @@ void ONNXTransposeOp::inferShapes() {
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
} }
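
// For illustration only: the transposition rule used above. With a `perm`
// attribute the i-th output dimension is the perm[i]-th input dimension;
// without it the shape is reversed, matching numpy.transpose. Sketch with
// assumed names, not the ONNF code:
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<int64_t>
transposedShape(const std::vector<int64_t> &inShape,
                const std::vector<int64_t> *perm /* nullptr if absent */) {
  std::vector<int64_t> outShape;
  if (perm) {
    for (int64_t p : *perm)
      outShape.push_back(inShape[static_cast<std::size_t>(p)]);
  } else {
    outShape.assign(inShape.rbegin(), inShape.rend());
  }
  return outShape;
}
// Example: {5, 5, 32, 64} with perm = {0, 2, 3, 1} gives {5, 32, 64, 5};
// with no perm attribute it gives {64, 32, 5, 5}.
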
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReduceMax // ReduceMax
void ONNXReduceMaxOp::inferShapes() { void ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked");
return; return;
} }
@ -734,7 +711,7 @@ void ONNXReduceMaxOp::inferShapes() {
void ONNXReduceMinOp::inferShapes() { void ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked");
return; return;
} }
@ -748,7 +725,7 @@ void ONNXReduceMinOp::inferShapes() {
void ONNXReduceProdOp::inferShapes() { void ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked");
return; return;
} }
@ -762,7 +739,7 @@ void ONNXReduceProdOp::inferShapes() {
void ONNXReduceSumOp::inferShapes() { void ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked");
return; return;
} }
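
// For illustration only: the Reduce* ops above all delegate to
// getReductionOutputType (its signature appears at the top of this diff).
// The rule it implements: reduced axes become size 1 when keepdims is set and
// are dropped otherwise; with no axes attribute every axis is reduced.
// Sketch with assumed names (negative-axis normalization omitted):
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<int64_t>
reducedShape(const std::vector<int64_t> &inShape, std::vector<int64_t> axes,
             bool keepdims) {
  if (axes.empty())
    for (int64_t i = 0; i < static_cast<int64_t>(inShape.size()); ++i)
      axes.push_back(i); // default: reduce over all axes
  std::vector<int64_t> outShape;
  for (int64_t i = 0; i < static_cast<int64_t>(inShape.size()); ++i) {
    bool reduced = std::find(axes.begin(), axes.end(), i) != axes.end();
    if (!reduced)
      outShape.push_back(inShape[static_cast<std::size_t>(i)]);
    else if (keepdims)
      outShape.push_back(1);
  }
  return outShape;
}
// Example: {3, 4, 5} reduced over axis 1 gives {3, 1, 5} with keepdims and
// {3, 5} without.
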
@ -781,30 +758,31 @@ void ONNXConvNoBiasOp::inferShapes() {
// W: (M x C/group x k1 x k2 x ... x kn) // W: (M x C/group x k1 x k2 x ... x kn)
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!X().getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !W().getType().isa<RankedTensorType>())
return; return;
auto dataTy = getOperand(0).getType().cast<RankedTensorType>(); auto dataTy = X().getType().cast<RankedTensorType>();
auto weightTy = getOperand(1).getType().cast<RankedTensorType>(); auto weightTy = W().getType().cast<RankedTensorType>();
auto dataShape = dataTy.getShape(); auto dataShape = dataTy.getShape();
auto weightShape = weightTy.getShape(); auto weightShape = weightTy.getShape();
// Lowest supported convolution is a one dimensional convolution. // Lowest supported convolution is a one dimensional convolution.
if (dataShape.size() < 3) if (dataShape.size() < 3)
emitError("Data input shape must be at least (NxCxD1)."); emitError("Data input shape must be at least (NxCxD1)");
// Check that shape of weight and data have same length. // Check that shape of weight and data have same length.
if (dataShape.size() != weightShape.size()) if (dataShape.size() != weightShape.size())
emitError("Weight size not compatible with data size."); emitError("Weight size not compatible with data size");
// Required attribute auto_pad defaults to NOTSET. // Required attribute auto_pad defaults to NOTSET.
auto autoPad = auto_pad(); auto autoPad = auto_pad();
// Group is a required attribute and should have default value of 1. // Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue(); int64_t group =
ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds. // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (dataShape[1] != (weightShape[1] * group)) if (dataShape[1] != (weightShape[1] * group))
emitError("Channel dimension mismatch."); emitError("Channel dimension mismatch");
// Note: the value of the group attribute only impacts the way the // Note: the value of the group attribute only impacts the way the
// computation is carried out and not the actual output size. // computation is carried out and not the actual output size.
@ -834,11 +812,10 @@ void ONNXConvNoBiasOp::inferShapes() {
// argument. // argument.
SmallVector<int64_t, 2> kernelDims; SmallVector<int64_t, 2> kernelDims;
if (auto kernelShape = kernel_shapeAttr()) { if (auto kernelShape = kernel_shapeAttr()) {
if (kernelShape.getValue().size() != nDims) if (ArrayAttrSize(kernelShape) != nDims)
emitError("kernel_shape length incompatible with spatial dimensions."); emitError("kernel_shape length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i) for (int i = 0; i < nDims; ++i)
kernelDims.emplace_back( kernelDims.emplace_back(ArrayAttrIntVal(kernelShape, i));
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
} else { } else {
for (int i = 0; i < nDims; ++i) for (int i = 0; i < nDims; ++i)
kernelDims.emplace_back(weightShape[i + 2]); kernelDims.emplace_back(weightShape[i + 2]);
@ -856,11 +833,11 @@ void ONNXConvNoBiasOp::inferShapes() {
// From a dimensionality perspective the kernel size becomes the dilated // From a dimensionality perspective the kernel size becomes the dilated
// kernel size. // kernel size.
if (auto dilations = dilationsAttr()) { if (auto dilations = dilationsAttr()) {
if (dilations.getValue().size() != nDims) if (ArrayAttrSize(dilations) != nDims)
emitError("dilations length incompatible with spatial dimensions."); emitError("dilations length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i) for (int i = 0; i < nDims; ++i)
kernelDims[i] = (kernelDims[i] + 1) * kernelDims[i] =
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() - 1; (kernelDims[i] + 1) * ArrayAttrIntVal(dilations, i) - 1;
} }
// Subtract kernel dimensions from input data dimensions. // Subtract kernel dimensions from input data dimensions.
@ -873,16 +850,14 @@ void ONNXConvNoBiasOp::inferShapes() {
// present then pads is considered to be all zeros (no padding). // present then pads is considered to be all zeros (no padding).
if (auto pads = padsAttr()) { if (auto pads = padsAttr()) {
// pads consists of two entries for each spatial axis. // pads consists of two entries for each spatial axis.
if (pads.getValue().size() != 2 * nDims) if (ArrayAttrSize(pads) != 2 * nDims)
emitError("pads size is not twice the spatial size."); emitError("pads size is not twice the spatial size");
for (int i = 0; i < nDims; ++i) { for (int i = 0; i < nDims; ++i) {
// Padding for beginning of axis. // Padding for beginning of axis.
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt(); outSpatialDims[i] += ArrayAttrIntVal(pads, i);
outSpatialDims[i] += p;
// Padding for end of axis. // Padding for end of axis.
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt(); outSpatialDims[i] += ArrayAttrIntVal(pads, i + nDims);
outSpatialDims[i] += p;
} }
} }
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
@ -898,16 +873,15 @@ void ONNXConvNoBiasOp::inferShapes() {
} else if (autoPad == "VALID") { } else if (autoPad == "VALID") {
// No padding // No padding
} else { } else {
emitError("Unexpected attribute value for auto_pad."); emitError("Unexpected attribute value for auto_pad");
} }
// Strides // Strides
if (auto strides = ONNXConvNoBiasOp::stridesAttr()) { if (auto strides = ONNXConvNoBiasOp::stridesAttr()) {
if (strides.getValue().size() != nDims) if (ArrayAttrSize(strides) != nDims)
emitError("strides length incompatible with spatial dimensions."); emitError("strides length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i) { for (int i = 0; i < nDims; ++i) {
int64_t stride = int64_t stride = ArrayAttrIntVal(strides, i);
strides.getValue()[i].cast<IntegerAttr>().getInt();
outSpatialDims[i] = floor(outSpatialDims[i] / stride); outSpatialDims[i] = floor(outSpatialDims[i] / stride);
} }
} }
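
// For reference: the standard ONNX convolution output-size rule that this
// shape inference targets is, per spatial axis,
//   out = floor((in + pad_begin + pad_end - ((kernel - 1) * dilation + 1))
//               / stride) + 1.
// A plain restatement for illustration (not the ONNF code, which also handles
// the SAME_UPPER/SAME_LOWER/VALID auto_pad modes above):
#include <cstdint>

static int64_t convOutputDim(int64_t in, int64_t kernel, int64_t padBegin,
                             int64_t padEnd, int64_t stride, int64_t dilation) {
  int64_t dilatedKernel = (kernel - 1) * dilation + 1;
  // Integer division of non-negative values gives the required floor.
  return (in + padBegin + padEnd - dilatedKernel) / stride + 1;
}
// Example: in = 32, kernel = 3, pads = 1 and 1, stride = 2, dilation = 1
// gives (32 + 2 - 3) / 2 + 1 = 16.
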
@ -922,144 +896,140 @@ void ONNXConvNoBiasOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MaxPoolSingleOut // MaxPoolSingleOut
// Infer the output shape and normalize the following attributes:
// - auto_pad set to NOTSET;
// - dilations, strides: set to 1 if not defined by user;
// - pads: set to proper value, 0 if not defined by user.
void ONNXMaxPoolSingleOutOp::inferShapes() { void ONNXMaxPoolSingleOutOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>()) if (!X().getType().isa<RankedTensorType>())
return; return;
auto builder = mlir::Builder(this->getContext());
// 1) get shape of input // 1) Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>(); auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape(); auto xShape = xTy.getShape();
auto xRank = xShape.size(); auto xRank = xShape.size();
// 2) analyse parameters // 2) Analyse parameters. Get kernel sizes from kernel_shape attribute.
// get kernel sizes from kernel_shape attribute
auto kernelShape = kernel_shape(); auto kernelShape = kernel_shape();
if (!kernelShape) if (!kernelShape)
emitError("kernel_shape is a mandatory attribute for which there is no default."); emitError(
auto kernelShapeArray = kernelShape.getValue(); "kernel_shape is a mandatory attribute for which there is no default");
auto kernelRank = kernelShape.size(); auto kernelRank = ArrayAttrSize(kernelShape);
if (kernelRank > xRank) if (kernelRank > xRank)
emitError("kernel_shape spatial dimension is too large."); emitError("kernel_shape spatial dimension is too large");
auto kernelOffset = xRank - kernelRank; auto kernelOffset = xRank - kernelRank;
// ceil mode // Ceil mode.
auto ceilMode = ceil_mode().getSExtValue(); auto ceilMode = ceil_mode().getSExtValue();
// dilatation // Dilatation.
SmallVector<int64_t, 4> actualDilations;
auto dilationsOpt = dilations(); auto dilationsOpt = dilations();
if (dilationsOpt.hasValue()) { if (dilationsOpt.hasValue()) {
auto dilationsArray = dilationsOpt.getValue().getValue(); // opt -> attr -> array if (ArrayAttrSize(dilationsOpt) != kernelRank)
if (dilationsArray.size() != kernelRank) emitError("dilation rank is not the same as the spatial rank");
emitError("dilation rank is not the same as the spatial rank."); // Test values.
// fill in the actual values
for (int i = 0; i < kernelRank; ++i) { for (int i = 0; i < kernelRank; ++i) {
int64_t d = (dilationsArray[i]).cast<IntegerAttr>().getInt(); if (ArrayAttrIntVal(dilationsOpt, i) < 1)
if (d < 1) emitError("dialation value must be nonzero positive");
emitError("dialation value must be nonzero positive.");
actualDilations.emplace_back(d);
} }
} else { } else {
for(int i=0; i < kernelRank; ++i) { // Default dilatation is needed.
actualDilations.emplace_back(1); SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
} // Convert to ArrayRef, then build attribute, then store attribute.
ArrayRef<int64_t> defaultRefs(defaultVals);
auto defaultAttr = builder.getI64ArrayAttr(defaultRefs);
dilationsAttr(defaultAttr);
dilationsOpt = dilations();
} }
// storage order // Storage order.
auto storageOrder = storage_order().getSExtValue();
// strides if (storageOrder != 0)
SmallVector<int64_t, 4> actualStrides; emitError("column major storage order not supported at this time");
// Strides.
auto stridesOpt = strides(); auto stridesOpt = strides();
if (stridesOpt.hasValue()) { if (stridesOpt.hasValue()) {
auto stridesArray = stridesOpt.getValue().getValue(); if (ArrayAttrSize(stridesOpt) != kernelRank)
if (stridesArray.size() != kernelRank) emitError("strides rank is not the same as the spatial rank");
emitError("strides rank is not the same as the spatial rank."); // Check values.
// fill in the actual values
for (int i = 0; i < kernelRank; ++i) { for (int i = 0; i < kernelRank; ++i) {
int64_t s = (stridesArray[i]).cast<IntegerAttr>().getInt(); if (ArrayAttrIntVal(stridesOpt, i) < 1)
if (s < 1) emitError("strides value must be nonzero positive");
emitError("strides value must be nonzero positive.");
actualStrides.emplace_back(s);
} }
} else { } else {
for(int i=0; i < kernelRank; ++i) { SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
actualStrides.emplace_back(1); // Convert to ArrayRef, then build attribute, then store attribute.
} ArrayRef<int64_t> defaultRefs(defaultVals);
auto defaultAttr = builder.getI64ArrayAttr(defaultRefs);
stridesAttr(defaultAttr);
stridesOpt = strides();
} }
// now try to find padding, getting auto_pad attribute first // Now try to find padding, getting auto_pad attribute first.
auto autoPad = auto_pad(); auto autoPad = auto_pad();
// and then investigate the various different cases // And then investigate the various different cases.
SmallVector<int64_t, 4> actualPads; SmallVector<int64_t, 4> actualPads(2 * kernelRank, 0);
auto defaultPads = false;
if (autoPad == "NOTSET") { if (autoPad == "NOTSET") {
auto padsOpt = pads(); auto padsOpt = pads();
if (padsOpt.hasValue()) { if (padsOpt.hasValue()) {
auto padsArray = padsOpt.getValue().getValue(); // Pads consists of two entries for each spatial axis.
// pads consists of two entries for each spatial axis. if (ArrayAttrSize(padsOpt) != 2 * kernelRank)
if (padsArray.size() != 2 * kernelRank) emitError("pads rank is not twice the spatial rank");
emitError("pads rank is not twice the spatial rank."); // Check values
// fill in the actual values for (int i = 0; i < 2 * kernelRank; ++i) {
for (int i = 0; i < 2*kernelRank; ++i) { int64_t p = ArrayAttrIntVal(padsOpt, i);
int64_t p = (padsArray[i]).cast<IntegerAttr>().getInt(); if (p < 0)
if (p < 0) emitError("pads value must be nonnegative");
emitError("pads value must be nonnegative."); actualPads[i] = p;
actualPads.emplace_back(p);
} }
} else {
// pads are not defined, default to value 0
defaultPads = true;
} }
} else if (autoPad == "VALID") {
defaultPads = true;
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
// init pad with zero for (int i = 0; i < kernelRank; ++i) {
for(int i=0; i<2*kernelRank; ++i) { auto inputSpatialShape = xShape[kernelOffset + i];
actualPads.emplace_back(0); auto kernelSpatialShape = ArrayAttrIntVal(kernelShape, i);
} auto dilations = ArrayAttrIntVal(dilationsOpt, i);
for(int i=0; i<kernelRank; ++i) { auto strideSpatialShape = ArrayAttrIntVal(stridesOpt, i);
auto inputSpatialShape = xShape[kernelOffset + i]; int64_t outputSpatialShape =
auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt(); ceil((1.0 * inputSpatialShape) / (1.0 * strideSpatialShape));
auto dilations = actualDilations[i]; auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape +
auto strideSpatialShape = actualStrides[i]; ((kernelSpatialShape - 1) * dilations + 1) -
int64_t outputSpatialShape = ceil((1.0 * inputSpatialShape) / inputSpatialShape;
(1.0 * strideSpatialShape));
auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape +
((kernelSpatialShape - 1) * dilations + 1) - inputSpatialShape;
actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2; actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2;
if (sumOfPad % 2 != 0) { if (sumOfPad % 2 != 0) {
if (autoPad == "SAME_UPPER") { if (autoPad == "SAME_UPPER") {
actualPads[kernelRank + i] += 1; actualPads[kernelRank + i] += 1;
} else { } else {
actualPads[i] += 1; actualPads[i] += 1;
} }
} }
} }
} else { } else if (autoPad != "VALID") {
emitError("auto_pad of unknown / unsupported value."); emitError("auto_pad of unknown / unsupported value");
} }
// handle case where default pad values must be used // Set pads values in attributes.
if (defaultPads) { {
for(int i=0; i<2*kernelRank; ++i) { ArrayRef<int64_t> defaultRefs(actualPads);
actualPads.emplace_back(0); auto defaultAttr = builder.getI64ArrayAttr(defaultRefs);
} padsAttr(defaultAttr);
auto defaultAutoPadAttr = builder.getStringAttr("NOTSET");
auto_padAttr(defaultAutoPadAttr);
} }
// initialize output shape // Initialize output shape.
SmallVector<int64_t, 4> yShape(xShape.begin(), xShape.end()); SmallVector<int64_t, 4> yShape(xShape.begin(), xShape.end());
// for all kernel dimensions // Process for all kernel dimensions.
for(int i=0; i<kernelRank; ++i) { for (int i = 0; i < kernelRank; ++i) {
auto inputSpatialShape = xShape[kernelOffset + i]; auto inputSpatialShape = xShape[kernelOffset + i];
auto padShape = actualPads[i] + actualPads[kernelRank+i]; auto padShape = actualPads[i] + actualPads[kernelRank + i];
auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt(); auto kernelSpatialShape = ArrayAttrIntVal(kernelShape, i);
auto dilations = actualDilations[i]; auto dilations = ArrayAttrIntVal(dilationsOpt, i);
auto strideSpatialShape = actualStrides[i]; auto strideSpatialShape = ArrayAttrIntVal(stridesOpt, i);
///output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] - double numerator = inputSpatialShape + padShape -
// ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1) ((kernelSpatialShape - 1) * dilations + 1);
double numerator = inputSpatialShape + padShape -
((kernelSpatialShape - 1) * dilations + 1);
double denominator = strideSpatialShape; double denominator = strideSpatialShape;
int64_t res; int64_t res;
if (ceilMode) { if (ceilMode) {
@ -1069,7 +1039,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
} }
yShape[kernelOffset + i] = res; yShape[kernelOffset + i] = res;
} }
auto arrayTy = getOperand().getType().cast<RankedTensorType>(); auto arrayTy = X().getType().cast<RankedTensorType>();
getResult().setType(RankedTensorType::get(yShape, arrayTy.getElementType())); getResult().setType(RankedTensorType::get(yShape, arrayTy.getElementType()));
} }
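
// For illustration only: the pooling arithmetic above, restated with the same
// numbers used by the shape-inference tests later in this diff.
//   out = floor_or_ceil((in + pad_begin + pad_end
//                        - ((kernel - 1) * dilation + 1)) / stride) + 1
#include <cmath>
#include <cstdint>

static int64_t poolOutputDim(int64_t in, int64_t kernel, int64_t padBegin,
                             int64_t padEnd, int64_t stride, int64_t dilation,
                             bool ceilMode) {
  double q = static_cast<double>(in + padBegin + padEnd -
                                 ((kernel - 1) * dilation + 1)) /
             static_cast<double>(stride);
  return static_cast<int64_t>(ceilMode ? std::ceil(q) : std::floor(q)) + 1;
}
// 1) Explicit pads: in = 32, kernel = 3, pads = 1/1, stride = 2, ceil_mode = 0
//    gives floor((32 + 2 - 3) / 2) + 1 = 16, matching the strided test case
//    (5x5x32x32 -> 5x5x16x16).
// 2) SAME_UPPER / SAME_LOWER: the target output is ceil(in / stride) and the
//    total padding is (out - 1) * stride + ((kernel - 1) * dilation + 1) - in,
//    split evenly with the odd leftover on the end (UPPER) or the beginning
//    (LOWER). For in = 13, kernel = 4, stride = 4 this gives out = 4 and a
//    total pad of 3, i.e. pads 1/2 (UPPER) or 2/1 (LOWER), which is exactly
//    the pads = [0, 1, 0, 2] and [0, 2, 0, 1] recorded by the tests below.
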
@ -1152,10 +1122,10 @@ void ONNXPadConstantValuePadOp::inferShapes(){
// Unsqueeze // Unsqueeze
void ONNXUnsqueezeOp::inferShapes() { void ONNXUnsqueezeOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) if (!data().getType().isa<RankedTensorType>())
return; return;
auto operandTy = getOperand().getType().cast<RankedTensorType>(); auto operandTy = data().getType().cast<RankedTensorType>();
int inRank = operandTy.getRank(); int inRank = operandTy.getRank();
ArrayAttr axisAttrs = axesAttr(); ArrayAttr axisAttrs = axesAttr();
@ -1171,10 +1141,10 @@ void ONNXUnsqueezeOp::inferShapes() {
if (std::find(axes.begin(), axes.end(), axis) == axes.end()) if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.emplace_back(axis); axes.emplace_back(axis);
else else
emitError("Duplicated axes."); emitError("Duplicated axes");
} }
} else { } else {
emitError("Axes attribute is required."); emitError("Axes attribute is required");
} }
SmallVector<int64_t, 4> dims; SmallVector<int64_t, 4> dims;
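
// For illustration only: the Unsqueeze rule being checked above. Each axis
// listed in `axes` inserts a new size-1 dimension at that position of the
// output; duplicated axes are rejected. Sketch with assumed names, not the
// ONNF code (negative axes omitted):
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<int64_t>
unsqueezedShape(const std::vector<int64_t> &inShape,
                const std::vector<int64_t> &axes) {
  std::vector<int64_t> outShape;
  std::size_t outRank = inShape.size() + axes.size();
  std::size_t inPos = 0;
  for (std::size_t i = 0; i < outRank; ++i) {
    bool isNewAxis = std::find(axes.begin(), axes.end(),
                               static_cast<int64_t>(i)) != axes.end();
    outShape.push_back(isNewAxis ? 1 : inShape[inPos++]);
  }
  return outShape;
}
// Example: {3, 4} with axes = {0, 2} gives {1, 3, 1, 4}.
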

File diff suppressed because it is too large


@ -127,6 +127,10 @@ int main(int argc, char *argv[]) {
if (emissionTarget >= EmitMLIR) { if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass()); pm.addPass(mlir::createLowerToKrnlPass());
// An additional pass of canonicalization is helpful because lowering
// from ONNX dialect to Standard dialect exposes additional canonicalization
// opportunities.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createLowerKrnlPass()); pm.addPass(mlir::createLowerKrnlPass());
} }


@ -28,6 +28,11 @@ void ONNXAddOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList& results, MLIRContext* context) {
results.insert<MulAddToGemmOptPattern>(context); results.insert<MulAddToGemmOptPattern>(context);
} }
void ONNXGemmOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<FuseGemmFollowedByAddition>(context);
}
/// on the ONNXIdentityOp. /// on the ONNXIdentityOp.
void ONNXIdentityOp::getCanonicalizationPatterns( void ONNXIdentityOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList& results, MLIRContext* context) {


@ -26,6 +26,7 @@ include "dialect/onnx/onnx.td"
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>; def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>; class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite // Pattern-Match and Rewrite
@ -41,6 +42,11 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
(ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)), (ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
[(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>; [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>;
// onnx.add(onnx.Gemm(%X, %Y, None), %Z) = onnx.Gemm(%X, %Y, %Z)
def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none, $alpha, $beta, $transA, $transB), $bias),
(ONNXGemmOp $m1, $m2, $bias, $alpha, $beta, $transA, $transB),
[(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasNoneType $none)]>;
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg), def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
(replaceWithValue $arg)>; (replaceWithValue $arg)>;


@ -118,7 +118,6 @@ public:
op->getName().getStringRef() != "onnx.Identity" && op->getName().getStringRef() != "onnx.Identity" &&
op->getName().getStringRef() != "onnx.MatMul" && op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" && op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.GemmNoBias" &&
op->getName().getStringRef() != "onnx.Reshape" && op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose" && op->getName().getStringRef() != "onnx.Transpose" &&
op->getName().getStringRef() != "onnx.ReduceMax" && op->getName().getStringRef() != "onnx.ReduceMax" &&


@ -101,3 +101,14 @@ func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>
// CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> // CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
// CHECK-NEXT: return %1 : tensor<*xf32> // CHECK-NEXT: return %1 : tensor<*xf32>
} }
// CHECK-LABEL: @test_gemm_add_fusion(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> {
func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32>
%1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32>
// return [[GEMM]] : tensor<*xf32>
}


@ -806,35 +806,6 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
// CHECK: } // CHECK: }
} }
func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
%0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_gemm_no_bias
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
// CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
// CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
// CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
// CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
// CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
// CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
// CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
// CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
// CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: }
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
// CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<10x10xf32>
// CHECK: }
}
func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()


@ -6,7 +6,7 @@ func @test_default_maxpoolsingleout(%arg0 : tensor<5x5x32x32xf32>) -> tensor<*xf
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
} }
// CHECK-LABEL: test_default_maxpoolsingleout // CHECK-LABEL: test_default_maxpoolsingleout
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "VALID", ceil_mode = 0 : i64, kernel_shape = [3, 3], pads = [1, 1, 1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> // CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32>
// CHECK: return [[RES]] : tensor<5x5x30x30xf32> // CHECK: return [[RES]] : tensor<5x5x30x30xf32>
@ -16,7 +16,7 @@ func @test_default_maxpoolsingleout_defpad(%arg0 : tensor<5x5x32x32xf32>) -> ten
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
} }
// CHECK-LABEL: test_default_maxpoolsingleout_defpad // CHECK-LABEL: test_default_maxpoolsingleout_defpad
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [3, 3]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> // CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32>
// CHECK: return [[RES]] : tensor<5x5x30x30xf32> // CHECK: return [[RES]] : tensor<5x5x30x30xf32>
@ -26,7 +26,7 @@ func @test_default_maxpoolsingleout_pad(%arg0 : tensor<5x5x32x32xf32>) -> tensor
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
} }
// CHECK-LABEL: test_default_maxpoolsingleout_pad // CHECK-LABEL: test_default_maxpoolsingleout_pad
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [3, 3], pads = [1, 1, 1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> // CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32>
// CHECK: return [[RES]] : tensor<5x5x32x32xf32> // CHECK: return [[RES]] : tensor<5x5x32x32xf32>
@ -36,7 +36,7 @@ func @test_default_maxpoolsingleout_pad_nonunif(%arg0 : tensor<5x5x32x32xf32>) -
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
} }
// CHECK-LABEL: test_default_maxpoolsingleout_pad_nonunif // CHECK-LABEL: test_default_maxpoolsingleout_pad_nonunif
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [2, 1, 1, 0]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x31x31xf32> // CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [5, 3], pads = [2, 1, 1, 0], strides = [1, 1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x31x31xf32>
// CHECK: return [[RES]] : tensor<5x5x31x31xf32> // CHECK: return [[RES]] : tensor<5x5x31x31xf32>
@ -46,7 +46,7 @@ func @test_default_maxpoolsingleout_strides(%arg0 : tensor<5x5x32x32xf32>) -> te
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
} }
// CHECK-LABEL: test_default_maxpoolsingleout_strides // CHECK-LABEL: test_default_maxpoolsingleout_strides
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x16x16xf32> // CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x16x16xf32>
// CHECK: return [[RES]] : tensor<5x5x16x16xf32> // CHECK: return [[RES]] : tensor<5x5x16x16xf32>
@ -56,7 +56,7 @@ func @test_default_maxpoolsingleout_strides_nonunifpad(%arg0 : tensor<5x5x30x32x
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
} }
// CHECK-LABEL: test_default_maxpoolsingleout_strides_nonunifpad // CHECK-LABEL: test_default_maxpoolsingleout_strides_nonunifpad
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32> // CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32>
// CHECK: return [[RES]] : tensor<5x5x15x16xf32> // CHECK: return [[RES]] : tensor<5x5x15x16xf32>
@ -66,7 +66,7 @@ func @test_default_maxpoolsingleout_strides_nonunifpad_ceil(%arg0 : tensor<5x5x3
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
} }
// CHECK-LABEL: test_default_maxpoolsingleout_strides_nonunifpad_ceil // CHECK-LABEL: test_default_maxpoolsingleout_strides_nonunifpad_ceil
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 1 : i64, kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x16x16xf32> // CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 1 : i64, dilations = [1, 1], kernel_shape = [2, 2], pads = [1, 0, 0, 0], strides = [2, 2]} : (tensor<5x5x30x32xf32>) -> tensor<5x5x16x16xf32>
// CHECK: return [[RES]] : tensor<5x5x16x16xf32> // CHECK: return [[RES]] : tensor<5x5x16x16xf32>
@ -76,7 +76,7 @@ func @test_default_maxpoolsingleout_strides_dilatation(%arg0 : tensor<5x5x8x8xf3
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
} }
// CHECK-LABEL: test_default_maxpoolsingleout_strides_dilatation // CHECK-LABEL: test_default_maxpoolsingleout_strides_dilatation
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [2, 2], kernel_shape = [2, 2], strides = [3, 3]} : (tensor<5x5x8x8xf32>) -> tensor<5x5x2x2xf32> // CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [2, 2], kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [3, 3]} : (tensor<5x5x8x8xf32>) -> tensor<5x5x2x2xf32>
// CHECK: return [[RES]] : tensor<5x5x2x2xf32> // CHECK: return [[RES]] : tensor<5x5x2x2xf32>
/// Test the default behavior of Max Pool with dilatation /// Test the default behavior of Max Pool with dilatation
@ -85,7 +85,7 @@ func @test_default_maxpoolsingleout_upper(%arg0 : tensor<5x5x16x13xf32>) -> tens
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
} }
// CHECK-LABEL: test_default_maxpoolsingleout_upper // CHECK-LABEL: test_default_maxpoolsingleout_upper
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "SAME_UPPER", ceil_mode = 0 : i64, kernel_shape = [4, 4], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32> // CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [4, 4], pads = [0, 1, 0, 2], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32>
// CHECK: return [[RES]] : tensor<5x5x4x4xf32> // CHECK: return [[RES]] : tensor<5x5x4x4xf32>
@ -95,6 +95,6 @@ func @test_default_maxpoolsingleout_lower(%arg0 : tensor<5x5x16x13xf32>) -> tens
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
} }
// CHECK-LABEL: test_default_maxpoolsingleout_lower // CHECK-LABEL: test_default_maxpoolsingleout_lower
// CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "SAME_LOWER", ceil_mode = 0 : i64, kernel_shape = [4, 4], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32> // CHECK: [[RES:%.+]] = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, dilations = [1, 1], kernel_shape = [4, 4], pads = [0, 2, 0, 1], strides = [4, 4]} : (tensor<5x5x16x13xf32>) -> tensor<5x5x4x4xf32>
// CHECK: return [[RES]] : tensor<5x5x4x4xf32> // CHECK: return [[RES]] : tensor<5x5x4x4xf32>

third_party/onnx

@ -1 +1 @@
Subproject commit 1439eab5542c625bb3da49860f0cd68c3eafdc18 Subproject commit 553df22c67bee5f0fe6599cff60f1afc6748c635